Skip to main content

qail_pg/protocol/wire/
backend.rs

1//! BackendMessage decoder — server-to-client wire format.
2
3use super::types::*;
4
/// Largest value the decoder accepts in a backend frame's length field (64 MiB).
///
/// The driver applies an equivalent guard before buffering. Repeating the bound
/// here keeps direct callers of `BackendMessage::decode` fail-closed when a
/// frame advertises an oversized length.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
/// Copy `bytes` into an owned `String`, rejecting invalid UTF-8.
///
/// `context` names the wire field being decoded; it prefixes the error text so
/// callers can tell which part of a message was malformed.
fn decode_utf8(bytes: &[u8], context: &str) -> Result<String, String> {
    match std::str::from_utf8(bytes) {
        Ok(text) => Ok(text.to_owned()),
        Err(err) => Err(format!("{} is not valid UTF-8: {}", context, err)),
    }
}
16
17impl BackendMessage {
18    /// Decode a message from wire bytes.
19    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
20        if buf.len() < 5 {
21            return Err("Buffer too short".to_string());
22        }
23
24        let msg_type = buf[0];
25        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
26
27        // PG protocol: length includes itself (4 bytes), so minimum valid length is 4.
28        // Anything less is a malformed message.
29        if len < 4 {
30            return Err(format!("Invalid message length: {} (minimum is 4)", len));
31        }
32        if len > MAX_BACKEND_FRAME_LEN {
33            return Err(format!(
34                "Message too large: {} bytes (max {})",
35                len, MAX_BACKEND_FRAME_LEN
36            ));
37        }
38
39        let frame_len = len
40            .checked_add(1)
41            .ok_or_else(|| "Message length overflow".to_string())?;
42
43        if buf.len() < frame_len {
44            return Err("Incomplete message".to_string());
45        }
46
47        let payload = &buf[5..frame_len];
48
49        let message = match msg_type {
50            b'R' => Self::decode_auth(payload)?,
51            b'S' => Self::decode_parameter_status(payload)?,
52            b'K' => Self::decode_backend_key(payload)?,
53            b'v' => Self::decode_negotiate_protocol_version(payload)?,
54            b'Z' => Self::decode_ready_for_query(payload)?,
55            b'T' => Self::decode_row_description(payload)?,
56            b'D' => Self::decode_data_row(payload)?,
57            b'C' => Self::decode_command_complete(payload)?,
58            b'E' => Self::decode_error_response(payload)?,
59            b'1' => {
60                if !payload.is_empty() {
61                    return Err("ParseComplete must have empty payload".to_string());
62                }
63                BackendMessage::ParseComplete
64            }
65            b'2' => {
66                if !payload.is_empty() {
67                    return Err("BindComplete must have empty payload".to_string());
68                }
69                BackendMessage::BindComplete
70            }
71            b'3' => {
72                if !payload.is_empty() {
73                    return Err("CloseComplete must have empty payload".to_string());
74                }
75                BackendMessage::CloseComplete
76            }
77            b'n' => {
78                if !payload.is_empty() {
79                    return Err("NoData must have empty payload".to_string());
80                }
81                BackendMessage::NoData
82            }
83            b's' => {
84                if !payload.is_empty() {
85                    return Err("PortalSuspended must have empty payload".to_string());
86                }
87                BackendMessage::PortalSuspended
88            }
89            b't' => Self::decode_parameter_description(payload)?,
90            b'G' => Self::decode_copy_in_response(payload)?,
91            b'H' => Self::decode_copy_out_response(payload)?,
92            b'W' => Self::decode_copy_both_response(payload)?,
93            b'd' => BackendMessage::CopyData(payload.to_vec()),
94            b'c' => {
95                if !payload.is_empty() {
96                    return Err("CopyDone must have empty payload".to_string());
97                }
98                BackendMessage::CopyDone
99            }
100            b'A' => Self::decode_notification_response(payload)?,
101            b'I' => {
102                if !payload.is_empty() {
103                    return Err("EmptyQueryResponse must have empty payload".to_string());
104                }
105                BackendMessage::EmptyQueryResponse
106            }
107            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
108            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
109        };
110
111        Ok((message, frame_len))
112    }
113
114    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
115        if payload.len() < 4 {
116            return Err("Auth payload too short".to_string());
117        }
118        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
119        match auth_type {
120            0 => {
121                if payload.len() != 4 {
122                    return Err(format!(
123                        "AuthenticationOk invalid payload length: {}",
124                        payload.len()
125                    ));
126                }
127                Ok(BackendMessage::AuthenticationOk)
128            }
129            2 => {
130                if payload.len() != 4 {
131                    return Err(format!(
132                        "AuthenticationKerberosV5 invalid payload length: {}",
133                        payload.len()
134                    ));
135                }
136                Ok(BackendMessage::AuthenticationKerberosV5)
137            }
138            3 => {
139                if payload.len() != 4 {
140                    return Err(format!(
141                        "AuthenticationCleartextPassword invalid payload length: {}",
142                        payload.len()
143                    ));
144                }
145                Ok(BackendMessage::AuthenticationCleartextPassword)
146            }
147            5 => {
148                if payload.len() != 8 {
149                    return Err("MD5 auth payload too short (need salt)".to_string());
150                }
151                let mut salt = [0u8; 4];
152                salt.copy_from_slice(&payload[4..8]);
153                Ok(BackendMessage::AuthenticationMD5Password(salt))
154            }
155            6 => {
156                if payload.len() != 4 {
157                    return Err(format!(
158                        "AuthenticationSCMCredential invalid payload length: {}",
159                        payload.len()
160                    ));
161                }
162                Ok(BackendMessage::AuthenticationSCMCredential)
163            }
164            7 => {
165                if payload.len() != 4 {
166                    return Err(format!(
167                        "AuthenticationGSS invalid payload length: {}",
168                        payload.len()
169                    ));
170                }
171                Ok(BackendMessage::AuthenticationGSS)
172            }
173            8 => Ok(BackendMessage::AuthenticationGSSContinue(
174                payload[4..].to_vec(),
175            )),
176            9 => {
177                if payload.len() != 4 {
178                    return Err(format!(
179                        "AuthenticationSSPI invalid payload length: {}",
180                        payload.len()
181                    ));
182                }
183                Ok(BackendMessage::AuthenticationSSPI)
184            }
185            10 => {
186                // SASL - parse mechanism list
187                let mut mechanisms = Vec::new();
188                let mut pos = 4;
189                while pos < payload.len() {
190                    if payload[pos] == 0 {
191                        break; // list terminator
192                    }
193                    let end = payload[pos..]
194                        .iter()
195                        .position(|&b| b == 0)
196                        .map(|p| pos + p)
197                        .ok_or("SASL mechanism list missing null terminator")?;
198                    mechanisms.push(decode_utf8(&payload[pos..end], "SASL mechanism")?);
199                    pos = end + 1;
200                }
201                if pos >= payload.len() {
202                    return Err("SASL mechanism list missing final terminator".to_string());
203                }
204                if pos + 1 != payload.len() {
205                    return Err("SASL mechanism list has trailing bytes".to_string());
206                }
207                if mechanisms.is_empty() {
208                    return Err("SASL mechanism list is empty".to_string());
209                }
210                Ok(BackendMessage::AuthenticationSASL(mechanisms))
211            }
212            11 => {
213                // SASL Continue - server challenge
214                Ok(BackendMessage::AuthenticationSASLContinue(
215                    payload[4..].to_vec(),
216                ))
217            }
218            12 => {
219                // SASL Final - server signature
220                Ok(BackendMessage::AuthenticationSASLFinal(
221                    payload[4..].to_vec(),
222                ))
223            }
224            _ => Err(format!("Unknown auth type: {}", auth_type)),
225        }
226    }
227
228    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
229        let name_end = payload
230            .iter()
231            .position(|&b| b == 0)
232            .ok_or("ParameterStatus missing name terminator")?;
233        let value_start = name_end + 1;
234        if value_start > payload.len() {
235            return Err("ParameterStatus missing value".to_string());
236        }
237        let value_end_rel = payload[value_start..]
238            .iter()
239            .position(|&b| b == 0)
240            .ok_or("ParameterStatus missing value terminator")?;
241        let value_end = value_start + value_end_rel;
242        if value_end + 1 != payload.len() {
243            return Err("ParameterStatus has trailing bytes".to_string());
244        }
245        Ok(BackendMessage::ParameterStatus {
246            name: decode_utf8(&payload[..name_end], "ParameterStatus name")?,
247            value: decode_utf8(&payload[value_start..value_end], "ParameterStatus value")?,
248        })
249    }
250
251    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
252        if payload.len() < 8 {
253            return Err("BackendKeyData payload too short".to_string());
254        }
255        let key_len = payload.len() - 4;
256        if !(4..=256).contains(&key_len) {
257            return Err(format!(
258                "BackendKeyData invalid secret key length: {} (expected 4..=256)",
259                key_len
260            ));
261        }
262        Ok(BackendMessage::BackendKeyData {
263            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
264            secret_key: payload[4..].to_vec(),
265        })
266    }
267
268    fn decode_negotiate_protocol_version(payload: &[u8]) -> Result<Self, String> {
269        if payload.len() < 8 {
270            return Err("NegotiateProtocolVersion payload too short".to_string());
271        }
272
273        let newest_minor_supported =
274            i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
275        if newest_minor_supported < 0 {
276            return Err("NegotiateProtocolVersion newest_minor_supported is negative".to_string());
277        }
278
279        let unrecognized_count =
280            i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
281        if unrecognized_count < 0 {
282            return Err(
283                "NegotiateProtocolVersion unrecognized option count is negative".to_string(),
284            );
285        }
286        let unrecognized_count = unrecognized_count as usize;
287        let remaining = payload.len() - 8;
288        // Each option is NUL-terminated, so even an empty option consumes one byte.
289        // This bound keeps malformed counts from forcing huge pre-allocations.
290        if unrecognized_count > remaining {
291            return Err(format!(
292                "NegotiateProtocolVersion unrecognized option count {} exceeds payload capacity {}",
293                unrecognized_count, remaining
294            ));
295        }
296
297        let mut options = Vec::with_capacity(unrecognized_count);
298        let mut pos = 8usize;
299        for _ in 0..unrecognized_count {
300            if pos >= payload.len() {
301                return Err("NegotiateProtocolVersion missing option string terminator".to_string());
302            }
303            let rel_end = payload[pos..]
304                .iter()
305                .position(|&b| b == 0)
306                .ok_or("NegotiateProtocolVersion option missing null terminator")?;
307            let end = pos + rel_end;
308            options.push(decode_utf8(
309                &payload[pos..end],
310                "NegotiateProtocolVersion option",
311            )?);
312            pos = end + 1;
313        }
314
315        if pos != payload.len() {
316            return Err("NegotiateProtocolVersion has trailing bytes".to_string());
317        }
318
319        Ok(BackendMessage::NegotiateProtocolVersion {
320            newest_minor_supported,
321            unrecognized_protocol_options: options,
322        })
323    }
324
325    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
326        if payload.len() != 1 {
327            return Err("ReadyForQuery payload empty".to_string());
328        }
329        let status = match payload[0] {
330            b'I' => TransactionStatus::Idle,
331            b'T' => TransactionStatus::InBlock,
332            b'E' => TransactionStatus::Failed,
333            _ => return Err("Unknown transaction status".to_string()),
334        };
335        Ok(BackendMessage::ReadyForQuery(status))
336    }
337
338    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
339        if payload.len() < 2 {
340            return Err("RowDescription payload too short".to_string());
341        }
342
343        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
344        if raw_count < 0 {
345            return Err(format!("RowDescription invalid field count: {}", raw_count));
346        }
347        let field_count = raw_count as usize;
348        let mut fields = Vec::with_capacity(field_count);
349        let mut pos = 2;
350
351        for _ in 0..field_count {
352            // Field name (null-terminated string)
353            let name_end = payload[pos..]
354                .iter()
355                .position(|&b| b == 0)
356                .ok_or("Missing null terminator in field name")?;
357            let name = decode_utf8(&payload[pos..pos + name_end], "RowDescription field name")?;
358            pos += name_end + 1; // Skip null terminator
359
360            // Ensure we have enough bytes for the fixed fields
361            if pos + 18 > payload.len() {
362                return Err("RowDescription field truncated".to_string());
363            }
364
365            let table_oid = u32::from_be_bytes([
366                payload[pos],
367                payload[pos + 1],
368                payload[pos + 2],
369                payload[pos + 3],
370            ]);
371            pos += 4;
372
373            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
374            pos += 2;
375
376            let type_oid = u32::from_be_bytes([
377                payload[pos],
378                payload[pos + 1],
379                payload[pos + 2],
380                payload[pos + 3],
381            ]);
382            pos += 4;
383
384            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
385            pos += 2;
386
387            let type_modifier = i32::from_be_bytes([
388                payload[pos],
389                payload[pos + 1],
390                payload[pos + 2],
391                payload[pos + 3],
392            ]);
393            pos += 4;
394
395            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
396            if !(0..=1).contains(&format) {
397                return Err(format!("RowDescription invalid format code: {}", format));
398            }
399            pos += 2;
400
401            fields.push(FieldDescription {
402                name,
403                table_oid,
404                column_attr,
405                type_oid,
406                type_size,
407                type_modifier,
408                format,
409            });
410        }
411
412        if pos != payload.len() {
413            return Err("RowDescription has trailing bytes".to_string());
414        }
415
416        Ok(BackendMessage::RowDescription(fields))
417    }
418
419    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
420        if payload.len() < 2 {
421            return Err("DataRow payload too short".to_string());
422        }
423
424        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
425        if raw_count < 0 {
426            return Err(format!("DataRow invalid column count: {}", raw_count));
427        }
428        let column_count = raw_count as usize;
429        // Sanity check: each column needs at least 4 bytes (length field)
430        if column_count > (payload.len() - 2) / 4 + 1 {
431            return Err(format!(
432                "DataRow claims {} columns but payload is only {} bytes",
433                column_count,
434                payload.len()
435            ));
436        }
437        let mut columns = Vec::with_capacity(column_count);
438        let mut pos = 2;
439
440        for _ in 0..column_count {
441            if pos + 4 > payload.len() {
442                return Err("DataRow truncated".to_string());
443            }
444
445            let len = i32::from_be_bytes([
446                payload[pos],
447                payload[pos + 1],
448                payload[pos + 2],
449                payload[pos + 3],
450            ]);
451            pos += 4;
452
453            if len == -1 {
454                // NULL value
455                columns.push(None);
456            } else {
457                if len < -1 {
458                    return Err(format!("DataRow invalid column length: {}", len));
459                }
460                let len = len as usize;
461                if len > payload.len().saturating_sub(pos) {
462                    return Err("DataRow column data truncated".to_string());
463                }
464                let data = payload[pos..pos + len].to_vec();
465                pos += len;
466                columns.push(Some(data));
467            }
468        }
469
470        if pos != payload.len() {
471            return Err("DataRow has trailing bytes".to_string());
472        }
473
474        Ok(BackendMessage::DataRow(columns))
475    }
476
477    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
478        if payload.last().copied() != Some(0) {
479            return Err("CommandComplete missing null terminator".to_string());
480        }
481        let tag_bytes = &payload[..payload.len() - 1];
482        if tag_bytes.contains(&0) {
483            return Err("CommandComplete contains interior null byte".to_string());
484        }
485        let tag = decode_utf8(tag_bytes, "CommandComplete tag")?;
486        Ok(BackendMessage::CommandComplete(tag))
487    }
488
489    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
490        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
491            payload,
492        )?))
493    }
494
495    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
496        if payload.last().copied() != Some(0) {
497            return Err("ErrorResponse missing final terminator".to_string());
498        }
499        let mut fields = ErrorFields::default();
500        let mut i = 0;
501        while i < payload.len() && payload[i] != 0 {
502            let field_type = payload[i];
503            i += 1;
504            let end = payload[i..]
505                .iter()
506                .position(|&b| b == 0)
507                .map(|p| p + i)
508                .ok_or("ErrorResponse field missing null terminator")?;
509            let value = decode_utf8(&payload[i..end], "ErrorResponse field")?;
510            i = end + 1;
511
512            match field_type {
513                b'S' => fields.severity = value,
514                b'C' => fields.code = value,
515                b'M' => fields.message = value,
516                b'D' => fields.detail = Some(value),
517                b'H' => fields.hint = Some(value),
518                _ => {}
519            }
520        }
521        if i + 1 != payload.len() {
522            return Err("ErrorResponse has trailing bytes after terminator".to_string());
523        }
524        Ok(fields)
525    }
526
527    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
528        if payload.len() < 2 {
529            return Err("ParameterDescription payload too short".to_string());
530        }
531        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
532        if raw_count < 0 {
533            return Err(format!("ParameterDescription invalid count: {}", raw_count));
534        }
535        let count = raw_count as usize;
536        let expected_len = 2 + count * 4;
537        if payload.len() < expected_len {
538            return Err(format!(
539                "ParameterDescription truncated: expected {} bytes, got {}",
540                expected_len,
541                payload.len()
542            ));
543        }
544        let mut oids = Vec::with_capacity(count);
545        let mut pos = 2;
546        for _ in 0..count {
547            oids.push(u32::from_be_bytes([
548                payload[pos],
549                payload[pos + 1],
550                payload[pos + 2],
551                payload[pos + 3],
552            ]));
553            pos += 4;
554        }
555        if pos != payload.len() {
556            return Err("ParameterDescription has trailing bytes".to_string());
557        }
558        Ok(BackendMessage::ParameterDescription(oids))
559    }
560
561    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
562        if payload.len() < 3 {
563            return Err("CopyInResponse payload too short".to_string());
564        }
565        let format = payload[0];
566        if format > 1 {
567            return Err(format!(
568                "CopyInResponse invalid overall format code: {}",
569                format
570            ));
571        }
572        let num_columns = if payload.len() >= 3 {
573            let raw = i16::from_be_bytes([payload[1], payload[2]]);
574            if raw < 0 {
575                return Err(format!(
576                    "CopyInResponse invalid negative column count: {}",
577                    raw
578                ));
579            }
580            raw as usize
581        } else {
582            0
583        };
584        let mut column_formats = Vec::with_capacity(num_columns);
585        let mut pos = 3usize;
586        for _ in 0..num_columns {
587            if pos + 2 > payload.len() {
588                return Err("CopyInResponse truncated column format list".to_string());
589            }
590            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
591            if !(0..=1).contains(&raw) {
592                return Err(format!("CopyInResponse invalid format code: {}", raw));
593            }
594            column_formats.push(raw as u8);
595            pos += 2;
596        }
597        if pos != payload.len() {
598            return Err("CopyInResponse has trailing bytes".to_string());
599        }
600        Ok(BackendMessage::CopyInResponse {
601            format,
602            column_formats,
603        })
604    }
605
606    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
607        if payload.len() < 3 {
608            return Err("CopyOutResponse payload too short".to_string());
609        }
610        let format = payload[0];
611        if format > 1 {
612            return Err(format!(
613                "CopyOutResponse invalid overall format code: {}",
614                format
615            ));
616        }
617        let num_columns = if payload.len() >= 3 {
618            let raw = i16::from_be_bytes([payload[1], payload[2]]);
619            if raw < 0 {
620                return Err(format!(
621                    "CopyOutResponse invalid negative column count: {}",
622                    raw
623                ));
624            }
625            raw as usize
626        } else {
627            0
628        };
629        let mut column_formats = Vec::with_capacity(num_columns);
630        let mut pos = 3usize;
631        for _ in 0..num_columns {
632            if pos + 2 > payload.len() {
633                return Err("CopyOutResponse truncated column format list".to_string());
634            }
635            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
636            if !(0..=1).contains(&raw) {
637                return Err(format!("CopyOutResponse invalid format code: {}", raw));
638            }
639            column_formats.push(raw as u8);
640            pos += 2;
641        }
642        if pos != payload.len() {
643            return Err("CopyOutResponse has trailing bytes".to_string());
644        }
645        Ok(BackendMessage::CopyOutResponse {
646            format,
647            column_formats,
648        })
649    }
650
651    fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
652        if payload.len() < 3 {
653            return Err("CopyBothResponse payload too short".to_string());
654        }
655        let format = payload[0];
656        if format > 1 {
657            return Err(format!(
658                "CopyBothResponse invalid overall format code: {}",
659                format
660            ));
661        }
662        let num_columns = if payload.len() >= 3 {
663            let raw = i16::from_be_bytes([payload[1], payload[2]]);
664            if raw < 0 {
665                return Err(format!(
666                    "CopyBothResponse invalid negative column count: {}",
667                    raw
668                ));
669            }
670            raw as usize
671        } else {
672            0
673        };
674        let mut column_formats = Vec::with_capacity(num_columns);
675        let mut pos = 3usize;
676        for _ in 0..num_columns {
677            if pos + 2 > payload.len() {
678                return Err("CopyBothResponse truncated column format list".to_string());
679            }
680            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
681            if !(0..=1).contains(&raw) {
682                return Err(format!("CopyBothResponse invalid format code: {}", raw));
683            }
684            column_formats.push(raw as u8);
685            pos += 2;
686        }
687        if pos != payload.len() {
688            return Err("CopyBothResponse has trailing bytes".to_string());
689        }
690        Ok(BackendMessage::CopyBothResponse {
691            format,
692            column_formats,
693        })
694    }
695
696    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
697        if payload.len() < 6 {
698            // Minimum: 4 (process_id) + 1 (channel NUL) + 1 (payload NUL)
699            return Err("NotificationResponse too short".to_string());
700        }
701        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
702
703        // Channel name (null-terminated)
704        let mut i = 4;
705        let remaining = payload.get(i..).unwrap_or(&[]);
706        let channel_end = remaining
707            .iter()
708            .position(|&b| b == 0)
709            .ok_or("NotificationResponse: missing channel null terminator")?;
710        let channel = decode_utf8(&remaining[..channel_end], "NotificationResponse channel")?;
711        i += channel_end + 1;
712
713        // Payload (null-terminated)
714        let remaining = payload.get(i..).unwrap_or(&[]);
715        let payload_end = remaining
716            .iter()
717            .position(|&b| b == 0)
718            .ok_or("NotificationResponse: missing payload null terminator")?;
719        let notification_payload =
720            decode_utf8(&remaining[..payload_end], "NotificationResponse payload")?;
721        if i + payload_end + 1 != payload.len() {
722            return Err("NotificationResponse has trailing bytes".to_string());
723        }
724
725        Ok(BackendMessage::NotificationResponse {
726            process_id,
727            channel,
728            payload: notification_payload,
729        })
730    }
731}