// fraiseql_wire/protocol/decode/mod.rs

//! Protocol message decoding
2
3use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
/// Bounds-checked read cursor over a byte slice.
///
/// Every accessor reports failure through `io::Result` instead of panicking,
/// so this file can stay panic-free under `#![deny(clippy::indexing_slicing)]`.
/// `offset` only ever moves forward, and only when a read succeeds.
struct Cursor<'a> {
    /// The complete input being decoded.
    data: &'a [u8],
    /// Index of the next unread byte; successful reads keep it `<= data.len()`.
    offset: usize,
}

impl<'a> Cursor<'a> {
    /// Creates a cursor positioned at the first byte of `data`.
    const fn new(data: &'a [u8]) -> Self {
        Self { data, offset: 0 }
    }

    /// Returns the bytes that have not been consumed yet (empty at end of input).
    fn remaining(&self) -> &'a [u8] {
        // `offset` never exceeds `data.len()`, so the fallback is only a
        // panic-free way of expressing the slice.
        self.data.get(self.offset..).unwrap_or(&[])
    }

    /// Returns `true` once every byte of the input has been consumed.
    const fn is_empty(&self) -> bool {
        self.offset >= self.data.len()
    }

    /// Builds the `UnexpectedEof` error used by all fixed-size reads.
    fn eof(what: &'static str) -> io::Error {
        io::Error::new(io::ErrorKind::UnexpectedEof, what)
    }

    /// Consumes exactly `n` bytes, or fails with `UnexpectedEof` (message `what`)
    /// without moving the cursor.
    fn take(&mut self, n: usize, what: &'static str) -> io::Result<&'a [u8]> {
        // `n` may be derived from a length field on the wire. A plain
        // `offset + n` would panic when overflow checks are enabled, so the
        // end position is computed with `checked_add` and overflow is reported
        // as a short read.
        let end = self
            .offset
            .checked_add(n)
            .ok_or_else(|| Self::eof(what))?;
        let bytes = self
            .data
            .get(self.offset..end)
            .ok_or_else(|| Self::eof(what))?;
        self.offset = end;
        Ok(bytes)
    }

    /// Reads one byte.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if the input is exhausted.
    fn read_u8(&mut self) -> io::Result<u8> {
        let byte = self
            .data
            .get(self.offset)
            .copied()
            .ok_or_else(|| Self::eof("byte"))?;
        // Cannot overflow: `offset` indexes an existing element.
        self.offset += 1;
        Ok(byte)
    }

    /// Reads a big-endian `i16`.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if fewer than two bytes remain.
    fn read_i16_be(&mut self) -> io::Result<i16> {
        match self.take(2, "i16")? {
            &[hi, lo] => Ok(i16::from_be_bytes([hi, lo])),
            // `take(2, ..)` always yields two bytes; kept only so the match is
            // exhaustive without a panic path.
            _ => Err(Self::eof("i16")),
        }
    }

    /// Reads a big-endian `i32`.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if fewer than four bytes remain.
    fn read_i32_be(&mut self) -> io::Result<i32> {
        match self.take(4, "i32")? {
            &[b0, b1, b2, b3] => Ok(i32::from_be_bytes([b0, b1, b2, b3])),
            // `take(4, ..)` always yields four bytes; see `read_i16_be`.
            _ => Err(Self::eof("i32")),
        }
    }

    /// Reads the next `n` bytes as a borrowed slice.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` if fewer than `n` bytes remain, including the case
    /// where `offset + n` does not fit in `usize`.
    fn read_slice(&mut self, n: usize) -> io::Result<&'a [u8]> {
        self.take(n, "slice")
    }

    /// Read until the next `0x00` byte (exclusive), advancing past the null.
    /// Returns the bytes before the null terminator.
    ///
    /// # Errors
    ///
    /// `InvalidData` if no null byte remains; the cursor is left untouched.
    fn read_until_null(&mut self) -> io::Result<&'a [u8]> {
        let tail = self.remaining();
        let end = self.position_of_null().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "missing null terminator in string",
            )
        })?;
        let bytes = tail.get(..end).unwrap_or(&[]);
        // Skip the `end` payload bytes plus the terminator itself. `end` is an
        // index into `tail`, so this stays within `data.len()`.
        self.offset += end + 1;
        Ok(bytes)
    }

    /// Find the next `0x00` byte in the remaining slice without advancing.
    /// The returned index is relative to the current position.
    fn position_of_null(&self) -> Option<usize> {
        self.remaining().iter().position(|&b| b == 0)
    }
}
98
/// One kibibyte, the unit used by the byte-size limits below.
const KIB: usize = 1024;

/// Upper bound on the field count accepted in a DataRow or RowDescription message.
///
/// PostgreSQL tables are limited to 1600 columns, so 2048 leaves headroom while
/// still bounding the `Vec::with_capacity` call that is sized from the
/// attacker-controllable count in the message.
pub(crate) const MAX_FIELD_COUNT: usize = 2_048;

/// Upper bound, in bytes, on one error/notice field string (severity, message, etc.).
///
/// 64 KiB is generous for human-readable diagnostics. The limit stops a
/// malicious server from forcing an arbitrarily large allocation in
/// `String::from_utf8_lossy` with a single oversized field.
pub(crate) const MAX_ERROR_FIELD_BYTES: usize = 64 * KIB;

/// Upper bound on the number of SASL mechanism names kept from an Authentication message.
///
/// Real providers advertise one or two mechanisms (e.g. SCRAM-SHA-256); stopping
/// at 32 keeps a rogue server from growing the `Vec<String>` without limit.
pub(crate) const MAX_SASL_MECHANISMS: usize = 32;

/// Upper bound, in bytes, on a ParameterStatus name (e.g. `"server_version"`).
///
/// Parameter names are short identifiers, so 256 bytes is more than enough.
pub(crate) const MAX_PARAMETER_NAME_BYTES: usize = 256;

/// Upper bound, in bytes, on a ParameterStatus value.
///
/// 64 KiB covers realistic values (long `TimeZone` strings, etc.) while keeping
/// a malicious server from inflating memory with an oversized value.
pub(crate) const MAX_PARAMETER_VALUE_BYTES: usize = 64 * KIB;
129
130/// Decode a backend message from `BytesMut` without cloning
131///
132/// This version decodes in-place from a mutable `BytesMut` buffer and returns
133/// the number of bytes consumed. The caller must advance the buffer after calling this.
134///
135/// # Errors
136///
137/// Returns `io::Error` with `UnexpectedEof` if the buffer does not contain a complete
138/// message. Returns `io::Error` with `InvalidData` if the message tag or length is invalid.
139///
140/// # Returns
141/// `Ok((msg, consumed))` - Message and number of bytes consumed
142/// `Err(e)` - IO error if message is incomplete or invalid
143///
144/// # Performance
145/// This version avoids the expensive `buf.clone().freeze()` call by working directly
146/// with references, reducing allocations and copies in the hot path.
147pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
148    if data.len() < 5 {
149        return Err(io::Error::new(
150            io::ErrorKind::UnexpectedEof,
151            "incomplete message header",
152        ));
153    }
154
155    let mut header = Cursor::new(data);
156    let tag = header.read_u8()?;
157    let len_i32 = header.read_i32_be()?;
158
159    // PostgreSQL message length includes the 4 length bytes but not the tag byte.
160    // Minimum valid length is 4 (just the length field itself).
161    if len_i32 < 4 {
162        return Err(io::Error::new(
163            io::ErrorKind::InvalidData,
164            "message length too small",
165        ));
166    }
167
168    let len = len_i32 as usize;
169
170    if data.len() < len + 1 {
171        return Err(io::Error::new(
172            io::ErrorKind::UnexpectedEof,
173            "incomplete message body",
174        ));
175    }
176
177    // Create a temporary slice starting after the tag and length
178    let msg_start = 5;
179    let msg_end = len + 1;
180    let msg_data = data
181        .get(msg_start..msg_end)
182        .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "message body slice"))?;
183
184    let msg = match tag {
185        tags::AUTHENTICATION => decode_authentication(msg_data)?,
186        tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
187        tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
188        tags::DATA_ROW => decode_data_row(msg_data)?,
189        tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
190        tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
191        tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
192        tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
193        tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
194        _ => {
195            return Err(io::Error::new(
196                io::ErrorKind::InvalidData,
197                format!("unknown message tag: {}", tag),
198            ))
199        }
200    };
201
202    Ok((msg, len + 1))
203}
204
205fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
206    let mut cur = Cursor::new(data);
207    let auth_type = cur
208        .read_i32_be()
209        .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"))?;
210
211    let auth_msg = match auth_type {
212        auth::OK => AuthenticationMessage::Ok,
213        auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
214        auth::MD5_PASSWORD => {
215            let salt_slice = cur
216                .read_slice(4)
217                .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"))?;
218            let salt: [u8; 4] = salt_slice
219                .try_into()
220                // Reason: provably-safe — `read_slice(4)` returns a 4-byte slice.
221                .expect("slice of length 4 always converts to [u8; 4]");
222            AuthenticationMessage::Md5Password { salt }
223        }
224        auth::SASL => {
225            // SASL: read mechanism list (null-terminated strings)
226            let mut mechanisms = Vec::new();
227            loop {
228                if cur.is_empty() {
229                    break;
230                }
231                let Some(end) = cur.position_of_null() else {
232                    break;
233                };
234                let mech_bytes = cur.read_slice(end).unwrap_or(&[]);
235                let mechanism = String::from_utf8_lossy(mech_bytes).to_string();
236                // Skip the null terminator we just located.
237                let _ = cur.read_u8();
238                if mechanism.is_empty() {
239                    break;
240                }
241                if mechanisms.len() >= MAX_SASL_MECHANISMS {
242                    break;
243                }
244                mechanisms.push(mechanism);
245            }
246            AuthenticationMessage::Sasl { mechanisms }
247        }
248        auth::SASL_CONTINUE => {
249            // SASL continue: read remaining data as bytes
250            let data_vec = cur.remaining().to_vec();
251            AuthenticationMessage::SaslContinue { data: data_vec }
252        }
253        auth::SASL_FINAL => {
254            // SASL final: read remaining data as bytes
255            let data_vec = cur.remaining().to_vec();
256            AuthenticationMessage::SaslFinal { data: data_vec }
257        }
258        _ => {
259            return Err(io::Error::new(
260                io::ErrorKind::Unsupported,
261                format!("unsupported auth type: {}", auth_type),
262            ))
263        }
264    };
265
266    Ok(BackendMessage::Authentication(auth_msg))
267}
268
269fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
270    let mut cur = Cursor::new(data);
271    let process_id = cur
272        .read_i32_be()
273        .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "backend key data"))?;
274    let secret_key = cur
275        .read_i32_be()
276        .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "backend key data"))?;
277    Ok(BackendMessage::BackendKeyData {
278        process_id,
279        secret_key,
280    })
281}
282
283fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
284    let mut cur = Cursor::new(data);
285    let tag_bytes = cur.read_until_null()?;
286    let tag = String::from_utf8_lossy(tag_bytes).to_string();
287    Ok(BackendMessage::CommandComplete(tag))
288}
289
290fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
291    let mut cur = Cursor::new(data);
292    let field_count_i16 = cur
293        .read_i16_be()
294        .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field count"))?;
295    if field_count_i16 < 0 {
296        return Err(io::Error::new(
297            io::ErrorKind::InvalidData,
298            "negative field count",
299        ));
300    }
301    let field_count = field_count_i16 as usize;
302    if field_count > MAX_FIELD_COUNT {
303        return Err(io::Error::new(
304            io::ErrorKind::InvalidData,
305            format!("DataRow field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
306        ));
307    }
308    let mut fields = Vec::with_capacity(field_count);
309
310    for _ in 0..field_count {
311        let field_len = cur
312            .read_i32_be()
313            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field length"))?;
314
315        let field = if field_len == -1 {
316            None
317        } else if field_len < 0 {
318            return Err(io::Error::new(
319                io::ErrorKind::InvalidData,
320                "negative field length",
321            ));
322        } else {
323            let len = field_len as usize;
324            let field_slice = cur
325                .read_slice(len)
326                .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field data"))?;
327            Some(Bytes::copy_from_slice(field_slice))
328        };
329        fields.push(field);
330    }
331
332    Ok(BackendMessage::DataRow(fields))
333}
334
335fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
336    let fields = decode_error_fields(data)?;
337    Ok(BackendMessage::ErrorResponse(fields))
338}
339
340fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
341    let fields = decode_error_fields(data)?;
342    Ok(BackendMessage::NoticeResponse(fields))
343}
344
345fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
346    let mut fields = ErrorFields::default();
347    let mut cur = Cursor::new(data);
348
349    loop {
350        if cur.is_empty() {
351            break;
352        }
353        let field_type = cur.read_u8()?;
354        if field_type == 0 {
355            break;
356        }
357
358        let end = cur.position_of_null().ok_or_else(|| {
359            io::Error::new(
360                io::ErrorKind::InvalidData,
361                "missing null terminator in error field",
362            )
363        })?;
364        if end > MAX_ERROR_FIELD_BYTES {
365            return Err(io::Error::new(
366                io::ErrorKind::InvalidData,
367                format!("Error field too large ({end} bytes, max {MAX_ERROR_FIELD_BYTES})"),
368            ));
369        }
370        let value_bytes = cur.read_slice(end).unwrap_or(&[]);
371        let value = String::from_utf8_lossy(value_bytes).to_string();
372        // Skip the null terminator.
373        let _ = cur.read_u8();
374
375        match field_type {
376            b'S' => fields.severity = Some(value),
377            b'C' => fields.code = Some(value),
378            b'M' => fields.message = Some(value),
379            b'D' => fields.detail = Some(value),
380            b'H' => fields.hint = Some(value),
381            b'P' => fields.position = Some(value),
382            _ => {} // Ignore unknown fields
383        }
384    }
385
386    Ok(fields)
387}
388
389fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
390    let mut cur = Cursor::new(data);
391
392    let name_end = cur.position_of_null().ok_or_else(|| {
393        io::Error::new(
394            io::ErrorKind::InvalidData,
395            "missing null terminator in parameter name",
396        )
397    })?;
398    if name_end > MAX_PARAMETER_NAME_BYTES {
399        return Err(io::Error::new(
400            io::ErrorKind::InvalidData,
401            format!("Parameter name too long ({name_end} bytes, max {MAX_PARAMETER_NAME_BYTES})"),
402        ));
403    }
404    let name_bytes = cur.read_slice(name_end).unwrap_or(&[]);
405    let name = String::from_utf8_lossy(name_bytes).to_string();
406    // Skip null terminator.
407    let _ = cur.read_u8();
408
409    if cur.is_empty() {
410        return Err(io::Error::new(
411            io::ErrorKind::UnexpectedEof,
412            "parameter value",
413        ));
414    }
415    let value_end = cur.position_of_null().ok_or_else(|| {
416        io::Error::new(
417            io::ErrorKind::InvalidData,
418            "missing null terminator in parameter value",
419        )
420    })?;
421    if value_end > MAX_PARAMETER_VALUE_BYTES {
422        return Err(io::Error::new(
423            io::ErrorKind::InvalidData,
424            format!(
425                "Parameter value too long ({value_end} bytes, max {MAX_PARAMETER_VALUE_BYTES})"
426            ),
427        ));
428    }
429    let value_bytes = cur.read_slice(value_end).unwrap_or(&[]);
430    let value = String::from_utf8_lossy(value_bytes).to_string();
431
432    Ok(BackendMessage::ParameterStatus { name, value })
433}
434
435fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
436    let status = *data
437        .first()
438        .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"))?;
439    Ok(BackendMessage::ReadyForQuery { status })
440}
441
442fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
443    let mut cur = Cursor::new(data);
444    let field_count_i16 = cur
445        .read_i16_be()
446        .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field count"))?;
447    if field_count_i16 < 0 {
448        return Err(io::Error::new(
449            io::ErrorKind::InvalidData,
450            "negative field count",
451        ));
452    }
453    let field_count = field_count_i16 as usize;
454    if field_count > MAX_FIELD_COUNT {
455        return Err(io::Error::new(
456            io::ErrorKind::InvalidData,
457            format!("RowDescription field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
458        ));
459    }
460    let mut fields = Vec::with_capacity(field_count);
461
462    for _ in 0..field_count {
463        // Read name (null-terminated string)
464        let name_end = cur.position_of_null().ok_or_else(|| {
465            io::Error::new(
466                io::ErrorKind::InvalidData,
467                "missing null terminator in field name",
468            )
469        })?;
470        let name_bytes = cur.read_slice(name_end).unwrap_or(&[]);
471        let name = String::from_utf8_lossy(name_bytes).to_string();
472        // Skip null terminator.
473        let _ = cur.read_u8();
474
475        // Read field descriptor (18 bytes: 4+2+4+2+4+2)
476        let table_oid = cur
477            .read_i32_be()
478            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
479        let column_attr = cur
480            .read_i16_be()
481            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
482        let type_oid = cur
483            .read_i32_be()
484            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?
485            as u32;
486        let type_size = cur
487            .read_i16_be()
488            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
489        let type_modifier = cur
490            .read_i32_be()
491            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
492        let format_code = cur
493            .read_i16_be()
494            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "field descriptor"))?;
495
496        fields.push(FieldDescription {
497            name,
498            table_oid,
499            column_attr,
500            type_oid,
501            type_size,
502            type_modifier,
503            format_code,
504        });
505    }
506
507    Ok(BackendMessage::RowDescription(fields))
508}
509
510#[cfg(test)]
511mod tests;