Skip to main content

qail_pg/protocol/wire/
backend.rs

1//! BackendMessage decoder — server-to-client wire format.
2
3use super::types::*;
4
/// Largest value of a backend frame's length field that the decoder accepts (64 MiB).
///
/// `BackendMessage::decode` rejects any frame declaring a larger length. This
/// matches the limit enforced by the driver, so standalone decoding also fails
/// closed on oversized frames.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
11impl BackendMessage {
12    /// Decode a message from wire bytes.
13    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
14        if buf.len() < 5 {
15            return Err("Buffer too short".to_string());
16        }
17
18        let msg_type = buf[0];
19        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
20
21        // PG protocol: length includes itself (4 bytes), so minimum valid length is 4.
22        // Anything less is a malformed message.
23        if len < 4 {
24            return Err(format!("Invalid message length: {} (minimum is 4)", len));
25        }
26        if len > MAX_BACKEND_FRAME_LEN {
27            return Err(format!(
28                "Message too large: {} bytes (max {})",
29                len, MAX_BACKEND_FRAME_LEN
30            ));
31        }
32
33        let frame_len = len
34            .checked_add(1)
35            .ok_or_else(|| "Message length overflow".to_string())?;
36
37        if buf.len() < frame_len {
38            return Err("Incomplete message".to_string());
39        }
40
41        let payload = &buf[5..frame_len];
42
43        let message = match msg_type {
44            b'R' => Self::decode_auth(payload)?,
45            b'S' => Self::decode_parameter_status(payload)?,
46            b'K' => Self::decode_backend_key(payload)?,
47            b'v' => Self::decode_negotiate_protocol_version(payload)?,
48            b'Z' => Self::decode_ready_for_query(payload)?,
49            b'T' => Self::decode_row_description(payload)?,
50            b'D' => Self::decode_data_row(payload)?,
51            b'C' => Self::decode_command_complete(payload)?,
52            b'E' => Self::decode_error_response(payload)?,
53            b'1' => {
54                if !payload.is_empty() {
55                    return Err("ParseComplete must have empty payload".to_string());
56                }
57                BackendMessage::ParseComplete
58            }
59            b'2' => {
60                if !payload.is_empty() {
61                    return Err("BindComplete must have empty payload".to_string());
62                }
63                BackendMessage::BindComplete
64            }
65            b'3' => {
66                if !payload.is_empty() {
67                    return Err("CloseComplete must have empty payload".to_string());
68                }
69                BackendMessage::CloseComplete
70            }
71            b'n' => {
72                if !payload.is_empty() {
73                    return Err("NoData must have empty payload".to_string());
74                }
75                BackendMessage::NoData
76            }
77            b's' => {
78                if !payload.is_empty() {
79                    return Err("PortalSuspended must have empty payload".to_string());
80                }
81                BackendMessage::PortalSuspended
82            }
83            b't' => Self::decode_parameter_description(payload)?,
84            b'G' => Self::decode_copy_in_response(payload)?,
85            b'H' => Self::decode_copy_out_response(payload)?,
86            b'W' => Self::decode_copy_both_response(payload)?,
87            b'd' => BackendMessage::CopyData(payload.to_vec()),
88            b'c' => {
89                if !payload.is_empty() {
90                    return Err("CopyDone must have empty payload".to_string());
91                }
92                BackendMessage::CopyDone
93            }
94            b'A' => Self::decode_notification_response(payload)?,
95            b'I' => {
96                if !payload.is_empty() {
97                    return Err("EmptyQueryResponse must have empty payload".to_string());
98                }
99                BackendMessage::EmptyQueryResponse
100            }
101            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
102            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
103        };
104
105        Ok((message, frame_len))
106    }
107
108    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
109        if payload.len() < 4 {
110            return Err("Auth payload too short".to_string());
111        }
112        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
113        match auth_type {
114            0 => {
115                if payload.len() != 4 {
116                    return Err(format!(
117                        "AuthenticationOk invalid payload length: {}",
118                        payload.len()
119                    ));
120                }
121                Ok(BackendMessage::AuthenticationOk)
122            }
123            2 => {
124                if payload.len() != 4 {
125                    return Err(format!(
126                        "AuthenticationKerberosV5 invalid payload length: {}",
127                        payload.len()
128                    ));
129                }
130                Ok(BackendMessage::AuthenticationKerberosV5)
131            }
132            3 => {
133                if payload.len() != 4 {
134                    return Err(format!(
135                        "AuthenticationCleartextPassword invalid payload length: {}",
136                        payload.len()
137                    ));
138                }
139                Ok(BackendMessage::AuthenticationCleartextPassword)
140            }
141            5 => {
142                if payload.len() != 8 {
143                    return Err("MD5 auth payload too short (need salt)".to_string());
144                }
145                let mut salt = [0u8; 4];
146                salt.copy_from_slice(&payload[4..8]);
147                Ok(BackendMessage::AuthenticationMD5Password(salt))
148            }
149            6 => {
150                if payload.len() != 4 {
151                    return Err(format!(
152                        "AuthenticationSCMCredential invalid payload length: {}",
153                        payload.len()
154                    ));
155                }
156                Ok(BackendMessage::AuthenticationSCMCredential)
157            }
158            7 => {
159                if payload.len() != 4 {
160                    return Err(format!(
161                        "AuthenticationGSS invalid payload length: {}",
162                        payload.len()
163                    ));
164                }
165                Ok(BackendMessage::AuthenticationGSS)
166            }
167            8 => Ok(BackendMessage::AuthenticationGSSContinue(
168                payload[4..].to_vec(),
169            )),
170            9 => {
171                if payload.len() != 4 {
172                    return Err(format!(
173                        "AuthenticationSSPI invalid payload length: {}",
174                        payload.len()
175                    ));
176                }
177                Ok(BackendMessage::AuthenticationSSPI)
178            }
179            10 => {
180                // SASL - parse mechanism list
181                let mut mechanisms = Vec::new();
182                let mut pos = 4;
183                while pos < payload.len() {
184                    if payload[pos] == 0 {
185                        break; // list terminator
186                    }
187                    let end = payload[pos..]
188                        .iter()
189                        .position(|&b| b == 0)
190                        .map(|p| pos + p)
191                        .ok_or("SASL mechanism list missing null terminator")?;
192                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
193                    pos = end + 1;
194                }
195                if pos >= payload.len() {
196                    return Err("SASL mechanism list missing final terminator".to_string());
197                }
198                if pos + 1 != payload.len() {
199                    return Err("SASL mechanism list has trailing bytes".to_string());
200                }
201                if mechanisms.is_empty() {
202                    return Err("SASL mechanism list is empty".to_string());
203                }
204                Ok(BackendMessage::AuthenticationSASL(mechanisms))
205            }
206            11 => {
207                // SASL Continue - server challenge
208                Ok(BackendMessage::AuthenticationSASLContinue(
209                    payload[4..].to_vec(),
210                ))
211            }
212            12 => {
213                // SASL Final - server signature
214                Ok(BackendMessage::AuthenticationSASLFinal(
215                    payload[4..].to_vec(),
216                ))
217            }
218            _ => Err(format!("Unknown auth type: {}", auth_type)),
219        }
220    }
221
222    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
223        let name_end = payload
224            .iter()
225            .position(|&b| b == 0)
226            .ok_or("ParameterStatus missing name terminator")?;
227        let value_start = name_end + 1;
228        if value_start > payload.len() {
229            return Err("ParameterStatus missing value".to_string());
230        }
231        let value_end_rel = payload[value_start..]
232            .iter()
233            .position(|&b| b == 0)
234            .ok_or("ParameterStatus missing value terminator")?;
235        let value_end = value_start + value_end_rel;
236        if value_end + 1 != payload.len() {
237            return Err("ParameterStatus has trailing bytes".to_string());
238        }
239        Ok(BackendMessage::ParameterStatus {
240            name: String::from_utf8_lossy(&payload[..name_end]).to_string(),
241            value: String::from_utf8_lossy(&payload[value_start..value_end]).to_string(),
242        })
243    }
244
245    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
246        if payload.len() < 8 {
247            return Err("BackendKeyData payload too short".to_string());
248        }
249        let key_len = payload.len() - 4;
250        if !(4..=256).contains(&key_len) {
251            return Err(format!(
252                "BackendKeyData invalid secret key length: {} (expected 4..=256)",
253                key_len
254            ));
255        }
256        Ok(BackendMessage::BackendKeyData {
257            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
258            secret_key: payload[4..].to_vec(),
259        })
260    }
261
262    fn decode_negotiate_protocol_version(payload: &[u8]) -> Result<Self, String> {
263        if payload.len() < 8 {
264            return Err("NegotiateProtocolVersion payload too short".to_string());
265        }
266
267        let newest_minor_supported =
268            i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
269        if newest_minor_supported < 0 {
270            return Err("NegotiateProtocolVersion newest_minor_supported is negative".to_string());
271        }
272
273        let unrecognized_count =
274            i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
275        if unrecognized_count < 0 {
276            return Err(
277                "NegotiateProtocolVersion unrecognized option count is negative".to_string(),
278            );
279        }
280        let unrecognized_count = unrecognized_count as usize;
281
282        let mut options = Vec::with_capacity(unrecognized_count);
283        let mut pos = 8usize;
284        for _ in 0..unrecognized_count {
285            if pos >= payload.len() {
286                return Err("NegotiateProtocolVersion missing option string terminator".to_string());
287            }
288            let rel_end = payload[pos..]
289                .iter()
290                .position(|&b| b == 0)
291                .ok_or("NegotiateProtocolVersion option missing null terminator")?;
292            let end = pos + rel_end;
293            options.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
294            pos = end + 1;
295        }
296
297        if pos != payload.len() {
298            return Err("NegotiateProtocolVersion has trailing bytes".to_string());
299        }
300
301        Ok(BackendMessage::NegotiateProtocolVersion {
302            newest_minor_supported,
303            unrecognized_protocol_options: options,
304        })
305    }
306
307    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
308        if payload.len() != 1 {
309            return Err("ReadyForQuery payload empty".to_string());
310        }
311        let status = match payload[0] {
312            b'I' => TransactionStatus::Idle,
313            b'T' => TransactionStatus::InBlock,
314            b'E' => TransactionStatus::Failed,
315            _ => return Err("Unknown transaction status".to_string()),
316        };
317        Ok(BackendMessage::ReadyForQuery(status))
318    }
319
320    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
321        if payload.len() < 2 {
322            return Err("RowDescription payload too short".to_string());
323        }
324
325        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
326        if raw_count < 0 {
327            return Err(format!("RowDescription invalid field count: {}", raw_count));
328        }
329        let field_count = raw_count as usize;
330        let mut fields = Vec::with_capacity(field_count);
331        let mut pos = 2;
332
333        for _ in 0..field_count {
334            // Field name (null-terminated string)
335            let name_end = payload[pos..]
336                .iter()
337                .position(|&b| b == 0)
338                .ok_or("Missing null terminator in field name")?;
339            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
340            pos += name_end + 1; // Skip null terminator
341
342            // Ensure we have enough bytes for the fixed fields
343            if pos + 18 > payload.len() {
344                return Err("RowDescription field truncated".to_string());
345            }
346
347            let table_oid = u32::from_be_bytes([
348                payload[pos],
349                payload[pos + 1],
350                payload[pos + 2],
351                payload[pos + 3],
352            ]);
353            pos += 4;
354
355            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
356            pos += 2;
357
358            let type_oid = u32::from_be_bytes([
359                payload[pos],
360                payload[pos + 1],
361                payload[pos + 2],
362                payload[pos + 3],
363            ]);
364            pos += 4;
365
366            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
367            pos += 2;
368
369            let type_modifier = i32::from_be_bytes([
370                payload[pos],
371                payload[pos + 1],
372                payload[pos + 2],
373                payload[pos + 3],
374            ]);
375            pos += 4;
376
377            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
378            if !(0..=1).contains(&format) {
379                return Err(format!("RowDescription invalid format code: {}", format));
380            }
381            pos += 2;
382
383            fields.push(FieldDescription {
384                name,
385                table_oid,
386                column_attr,
387                type_oid,
388                type_size,
389                type_modifier,
390                format,
391            });
392        }
393
394        if pos != payload.len() {
395            return Err("RowDescription has trailing bytes".to_string());
396        }
397
398        Ok(BackendMessage::RowDescription(fields))
399    }
400
401    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
402        if payload.len() < 2 {
403            return Err("DataRow payload too short".to_string());
404        }
405
406        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
407        if raw_count < 0 {
408            return Err(format!("DataRow invalid column count: {}", raw_count));
409        }
410        let column_count = raw_count as usize;
411        // Sanity check: each column needs at least 4 bytes (length field)
412        if column_count > (payload.len() - 2) / 4 + 1 {
413            return Err(format!(
414                "DataRow claims {} columns but payload is only {} bytes",
415                column_count,
416                payload.len()
417            ));
418        }
419        let mut columns = Vec::with_capacity(column_count);
420        let mut pos = 2;
421
422        for _ in 0..column_count {
423            if pos + 4 > payload.len() {
424                return Err("DataRow truncated".to_string());
425            }
426
427            let len = i32::from_be_bytes([
428                payload[pos],
429                payload[pos + 1],
430                payload[pos + 2],
431                payload[pos + 3],
432            ]);
433            pos += 4;
434
435            if len == -1 {
436                // NULL value
437                columns.push(None);
438            } else {
439                if len < -1 {
440                    return Err(format!("DataRow invalid column length: {}", len));
441                }
442                let len = len as usize;
443                if len > payload.len().saturating_sub(pos) {
444                    return Err("DataRow column data truncated".to_string());
445                }
446                let data = payload[pos..pos + len].to_vec();
447                pos += len;
448                columns.push(Some(data));
449            }
450        }
451
452        if pos != payload.len() {
453            return Err("DataRow has trailing bytes".to_string());
454        }
455
456        Ok(BackendMessage::DataRow(columns))
457    }
458
459    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
460        if payload.last().copied() != Some(0) {
461            return Err("CommandComplete missing null terminator".to_string());
462        }
463        let tag_bytes = &payload[..payload.len() - 1];
464        if tag_bytes.contains(&0) {
465            return Err("CommandComplete contains interior null byte".to_string());
466        }
467        let tag = String::from_utf8_lossy(tag_bytes).to_string();
468        Ok(BackendMessage::CommandComplete(tag))
469    }
470
471    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
472        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
473            payload,
474        )?))
475    }
476
477    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
478        if payload.last().copied() != Some(0) {
479            return Err("ErrorResponse missing final terminator".to_string());
480        }
481        let mut fields = ErrorFields::default();
482        let mut i = 0;
483        while i < payload.len() && payload[i] != 0 {
484            let field_type = payload[i];
485            i += 1;
486            let end = payload[i..]
487                .iter()
488                .position(|&b| b == 0)
489                .map(|p| p + i)
490                .ok_or("ErrorResponse field missing null terminator")?;
491            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
492            i = end + 1;
493
494            match field_type {
495                b'S' => fields.severity = value,
496                b'C' => fields.code = value,
497                b'M' => fields.message = value,
498                b'D' => fields.detail = Some(value),
499                b'H' => fields.hint = Some(value),
500                _ => {}
501            }
502        }
503        if i + 1 != payload.len() {
504            return Err("ErrorResponse has trailing bytes after terminator".to_string());
505        }
506        Ok(fields)
507    }
508
509    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
510        if payload.len() < 2 {
511            return Err("ParameterDescription payload too short".to_string());
512        }
513        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
514        if raw_count < 0 {
515            return Err(format!("ParameterDescription invalid count: {}", raw_count));
516        }
517        let count = raw_count as usize;
518        let expected_len = 2 + count * 4;
519        if payload.len() < expected_len {
520            return Err(format!(
521                "ParameterDescription truncated: expected {} bytes, got {}",
522                expected_len,
523                payload.len()
524            ));
525        }
526        let mut oids = Vec::with_capacity(count);
527        let mut pos = 2;
528        for _ in 0..count {
529            oids.push(u32::from_be_bytes([
530                payload[pos],
531                payload[pos + 1],
532                payload[pos + 2],
533                payload[pos + 3],
534            ]));
535            pos += 4;
536        }
537        if pos != payload.len() {
538            return Err("ParameterDescription has trailing bytes".to_string());
539        }
540        Ok(BackendMessage::ParameterDescription(oids))
541    }
542
543    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
544        if payload.len() < 3 {
545            return Err("CopyInResponse payload too short".to_string());
546        }
547        let format = payload[0];
548        if format > 1 {
549            return Err(format!(
550                "CopyInResponse invalid overall format code: {}",
551                format
552            ));
553        }
554        let num_columns = if payload.len() >= 3 {
555            let raw = i16::from_be_bytes([payload[1], payload[2]]);
556            if raw < 0 {
557                return Err(format!(
558                    "CopyInResponse invalid negative column count: {}",
559                    raw
560                ));
561            }
562            raw as usize
563        } else {
564            0
565        };
566        let mut column_formats = Vec::with_capacity(num_columns);
567        let mut pos = 3usize;
568        for _ in 0..num_columns {
569            if pos + 2 > payload.len() {
570                return Err("CopyInResponse truncated column format list".to_string());
571            }
572            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
573            if !(0..=1).contains(&raw) {
574                return Err(format!("CopyInResponse invalid format code: {}", raw));
575            }
576            column_formats.push(raw as u8);
577            pos += 2;
578        }
579        if pos != payload.len() {
580            return Err("CopyInResponse has trailing bytes".to_string());
581        }
582        Ok(BackendMessage::CopyInResponse {
583            format,
584            column_formats,
585        })
586    }
587
588    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
589        if payload.len() < 3 {
590            return Err("CopyOutResponse payload too short".to_string());
591        }
592        let format = payload[0];
593        if format > 1 {
594            return Err(format!(
595                "CopyOutResponse invalid overall format code: {}",
596                format
597            ));
598        }
599        let num_columns = if payload.len() >= 3 {
600            let raw = i16::from_be_bytes([payload[1], payload[2]]);
601            if raw < 0 {
602                return Err(format!(
603                    "CopyOutResponse invalid negative column count: {}",
604                    raw
605                ));
606            }
607            raw as usize
608        } else {
609            0
610        };
611        let mut column_formats = Vec::with_capacity(num_columns);
612        let mut pos = 3usize;
613        for _ in 0..num_columns {
614            if pos + 2 > payload.len() {
615                return Err("CopyOutResponse truncated column format list".to_string());
616            }
617            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
618            if !(0..=1).contains(&raw) {
619                return Err(format!("CopyOutResponse invalid format code: {}", raw));
620            }
621            column_formats.push(raw as u8);
622            pos += 2;
623        }
624        if pos != payload.len() {
625            return Err("CopyOutResponse has trailing bytes".to_string());
626        }
627        Ok(BackendMessage::CopyOutResponse {
628            format,
629            column_formats,
630        })
631    }
632
633    fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
634        if payload.len() < 3 {
635            return Err("CopyBothResponse payload too short".to_string());
636        }
637        let format = payload[0];
638        if format > 1 {
639            return Err(format!(
640                "CopyBothResponse invalid overall format code: {}",
641                format
642            ));
643        }
644        let num_columns = if payload.len() >= 3 {
645            let raw = i16::from_be_bytes([payload[1], payload[2]]);
646            if raw < 0 {
647                return Err(format!(
648                    "CopyBothResponse invalid negative column count: {}",
649                    raw
650                ));
651            }
652            raw as usize
653        } else {
654            0
655        };
656        let mut column_formats = Vec::with_capacity(num_columns);
657        let mut pos = 3usize;
658        for _ in 0..num_columns {
659            if pos + 2 > payload.len() {
660                return Err("CopyBothResponse truncated column format list".to_string());
661            }
662            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
663            if !(0..=1).contains(&raw) {
664                return Err(format!("CopyBothResponse invalid format code: {}", raw));
665            }
666            column_formats.push(raw as u8);
667            pos += 2;
668        }
669        if pos != payload.len() {
670            return Err("CopyBothResponse has trailing bytes".to_string());
671        }
672        Ok(BackendMessage::CopyBothResponse {
673            format,
674            column_formats,
675        })
676    }
677
678    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
679        if payload.len() < 6 {
680            // Minimum: 4 (process_id) + 1 (channel NUL) + 1 (payload NUL)
681            return Err("NotificationResponse too short".to_string());
682        }
683        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
684
685        // Channel name (null-terminated)
686        let mut i = 4;
687        let remaining = payload.get(i..).unwrap_or(&[]);
688        let channel_end = remaining
689            .iter()
690            .position(|&b| b == 0)
691            .ok_or("NotificationResponse: missing channel null terminator")?;
692        let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
693        i += channel_end + 1;
694
695        // Payload (null-terminated)
696        let remaining = payload.get(i..).unwrap_or(&[]);
697        let payload_end = remaining
698            .iter()
699            .position(|&b| b == 0)
700            .ok_or("NotificationResponse: missing payload null terminator")?;
701        let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
702        if i + payload_end + 1 != payload.len() {
703            return Err("NotificationResponse has trailing bytes".to_string());
704        }
705
706        Ok(BackendMessage::NotificationResponse {
707            process_id,
708            channel,
709            payload: notification_payload,
710        })
711    }
712}