Skip to main content

fraiseql_wire/protocol/
decode.rs

1//! Protocol message decoding
2
3use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
/// Largest value accepted in a backend message's length field: 2^30 bytes (1 GiB).
///
/// Any message whose length field exceeds this value is rejected before allocation
/// to prevent denial-of-service via crafted length headers. The ceiling is chosen to
/// be on the same ~1 GB order as PostgreSQL's own `PQ_LARGE_MESSAGE_LIMIT`.
const MAX_MESSAGE_LENGTH: usize = 1 << 30;
13
14/// Decode a backend message from BytesMut without cloning
15///
16/// This version decodes in-place from a mutable BytesMut buffer and returns
17/// the number of bytes consumed. The caller must advance the buffer after calling this.
18///
19/// # Returns
20/// `Ok((msg, consumed))` - Message and number of bytes consumed
21/// `Err(e)` - IO error if message is incomplete or invalid
22///
23/// # Performance
24/// This version avoids the expensive `buf.clone().freeze()` call by working directly
25/// with references, reducing allocations and copies in the hot path.
26pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
27    if data.len() < 5 {
28        return Err(io::Error::new(
29            io::ErrorKind::UnexpectedEof,
30            "incomplete message header",
31        ));
32    }
33
34    let tag = data[0];
35    let len = i32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
36
37    if len > MAX_MESSAGE_LENGTH {
38        return Err(io::Error::new(
39            io::ErrorKind::InvalidData,
40            format!(
41                "message length {} exceeds maximum allowed {}",
42                len, MAX_MESSAGE_LENGTH
43            ),
44        ));
45    }
46
47    if data.len() < len + 1 {
48        return Err(io::Error::new(
49            io::ErrorKind::UnexpectedEof,
50            "incomplete message body",
51        ));
52    }
53
54    // Create a temporary slice starting after the tag and length
55    let msg_start = 5;
56    let msg_end = len + 1;
57    let msg_data = &data[msg_start..msg_end];
58
59    let msg = match tag {
60        tags::AUTHENTICATION => decode_authentication(msg_data)?,
61        tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
62        tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
63        tags::DATA_ROW => decode_data_row(msg_data)?,
64        tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
65        tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
66        tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
67        tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
68        tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
69        _ => {
70            return Err(io::Error::new(
71                io::ErrorKind::InvalidData,
72                format!("unknown message tag: {}", tag),
73            ))
74        }
75    };
76
77    Ok((msg, len + 1))
78}
79
80fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
81    if data.len() < 4 {
82        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
83    }
84    let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
85
86    let auth_msg = match auth_type {
87        auth::OK => AuthenticationMessage::Ok,
88        auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
89        auth::MD5_PASSWORD => {
90            if data.len() < 8 {
91                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
92            }
93            let mut salt = [0u8; 4];
94            salt.copy_from_slice(&data[4..8]);
95            AuthenticationMessage::Md5Password { salt }
96        }
97        auth::SASL => {
98            // SASL: read mechanism list (null-terminated strings)
99            let mut mechanisms = Vec::new();
100            let remaining = &data[4..];
101            let mut offset = 0;
102            loop {
103                if offset >= remaining.len() {
104                    break;
105                }
106                match remaining[offset..].iter().position(|&b| b == 0) {
107                    Some(end) => {
108                        let mechanism =
109                            String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
110                        if mechanism.is_empty() {
111                            break;
112                        }
113                        mechanisms.push(mechanism);
114                        offset += end + 1;
115                    }
116                    None => break,
117                }
118            }
119            AuthenticationMessage::Sasl { mechanisms }
120        }
121        auth::SASL_CONTINUE => {
122            // SASL continue: read remaining data as bytes
123            let data_vec = data[4..].to_vec();
124            AuthenticationMessage::SaslContinue { data: data_vec }
125        }
126        auth::SASL_FINAL => {
127            // SASL final: read remaining data as bytes
128            let data_vec = data[4..].to_vec();
129            AuthenticationMessage::SaslFinal { data: data_vec }
130        }
131        _ => {
132            return Err(io::Error::new(
133                io::ErrorKind::Unsupported,
134                format!("unsupported auth type: {}", auth_type),
135            ))
136        }
137    };
138
139    Ok(BackendMessage::Authentication(auth_msg))
140}
141
142fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
143    if data.len() < 8 {
144        return Err(io::Error::new(
145            io::ErrorKind::UnexpectedEof,
146            "backend key data",
147        ));
148    }
149    let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
150    let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
151    Ok(BackendMessage::BackendKeyData {
152        process_id,
153        secret_key,
154    })
155}
156
157fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
158    let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
159        io::Error::new(
160            io::ErrorKind::InvalidData,
161            "missing null terminator in string",
162        )
163    })?;
164    let tag = String::from_utf8_lossy(&data[..end]).to_string();
165    Ok(BackendMessage::CommandComplete(tag))
166}
167
168fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
169    if data.len() < 2 {
170        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
171    }
172    let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
173    let mut fields = Vec::with_capacity(field_count);
174    let mut offset = 2;
175
176    for _ in 0..field_count {
177        if offset + 4 > data.len() {
178            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
179        }
180        let field_len = i32::from_be_bytes([
181            data[offset],
182            data[offset + 1],
183            data[offset + 2],
184            data[offset + 3],
185        ]);
186        offset += 4;
187
188        let field = if field_len == -1 {
189            None
190        } else {
191            let len = field_len as usize;
192            if offset + len > data.len() {
193                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
194            }
195            let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
196            offset += len;
197            Some(field_bytes)
198        };
199        fields.push(field);
200    }
201
202    Ok(BackendMessage::DataRow(fields))
203}
204
205fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
206    let fields = decode_error_fields(data)?;
207    Ok(BackendMessage::ErrorResponse(fields))
208}
209
210fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
211    let fields = decode_error_fields(data)?;
212    Ok(BackendMessage::NoticeResponse(fields))
213}
214
215fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
216    let mut fields = ErrorFields::default();
217    let mut offset = 0;
218
219    loop {
220        if offset >= data.len() {
221            break;
222        }
223        let field_type = data[offset];
224        offset += 1;
225        if field_type == 0 {
226            break;
227        }
228
229        let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
230            io::Error::new(
231                io::ErrorKind::InvalidData,
232                "missing null terminator in error field",
233            )
234        })?;
235        let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
236        offset += end + 1;
237
238        match field_type {
239            b'S' => fields.severity = Some(value),
240            b'C' => fields.code = Some(value),
241            b'M' => fields.message = Some(value),
242            b'D' => fields.detail = Some(value),
243            b'H' => fields.hint = Some(value),
244            b'P' => fields.position = Some(value),
245            _ => {} // Ignore unknown fields
246        }
247    }
248
249    Ok(fields)
250}
251
252fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
253    let mut offset = 0;
254
255    let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
256        io::Error::new(
257            io::ErrorKind::InvalidData,
258            "missing null terminator in parameter name",
259        )
260    })?;
261    let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
262    offset += name_end + 1;
263
264    if offset >= data.len() {
265        return Err(io::Error::new(
266            io::ErrorKind::UnexpectedEof,
267            "parameter value",
268        ));
269    }
270    let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
271        io::Error::new(
272            io::ErrorKind::InvalidData,
273            "missing null terminator in parameter value",
274        )
275    })?;
276    let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
277
278    Ok(BackendMessage::ParameterStatus { name, value })
279}
280
281fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
282    if data.is_empty() {
283        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
284    }
285    let status = data[0];
286    Ok(BackendMessage::ReadyForQuery { status })
287}
288
289fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
290    if data.len() < 2 {
291        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
292    }
293    let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
294    let mut fields = Vec::with_capacity(field_count);
295    let mut offset = 2;
296
297    for _ in 0..field_count {
298        // Read name (null-terminated string)
299        let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
300            io::Error::new(
301                io::ErrorKind::InvalidData,
302                "missing null terminator in field name",
303            )
304        })?;
305        let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
306        offset += name_end + 1;
307
308        // Read field descriptor (26 bytes: 4+2+4+2+4+2)
309        if offset + 18 > data.len() {
310            return Err(io::Error::new(
311                io::ErrorKind::UnexpectedEof,
312                "field descriptor",
313            ));
314        }
315        let table_oid = i32::from_be_bytes([
316            data[offset],
317            data[offset + 1],
318            data[offset + 2],
319            data[offset + 3],
320        ]);
321        offset += 4;
322        let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
323        offset += 2;
324        let type_oid = i32::from_be_bytes([
325            data[offset],
326            data[offset + 1],
327            data[offset + 2],
328            data[offset + 3],
329        ]) as u32;
330        offset += 4;
331        let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
332        offset += 2;
333        let type_modifier = i32::from_be_bytes([
334            data[offset],
335            data[offset + 1],
336            data[offset + 2],
337            data[offset + 3],
338        ]);
339        offset += 4;
340        let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
341        offset += 2;
342
343        fields.push(FieldDescription {
344            name,
345            table_oid,
346            column_attr,
347            type_oid,
348            type_size,
349            type_modifier,
350            format_code,
351        });
352    }
353
354    Ok(BackendMessage::RowDescription(fields))
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Frame `body` as a backend message: tag byte, then an Int32 length that
    /// counts itself but not the tag, then the body bytes.
    fn frame(tag: u8, body: &[u8]) -> BytesMut {
        let declared_len = (body.len() + 4) as i32;
        let mut buf = BytesMut::with_capacity(1 + 4 + body.len());
        buf.extend_from_slice(&[tag]);
        buf.extend_from_slice(&declared_len.to_be_bytes());
        buf.extend_from_slice(body);
        buf
    }

    #[test]
    fn test_decode_authentication_ok() {
        // 'R' with auth request code 0 (OK).
        let mut data = frame(b'R', &0i32.to_be_bytes());

        let (msg, consumed) = decode_message(&mut data).unwrap();
        assert!(
            matches!(msg, BackendMessage::Authentication(AuthenticationMessage::Ok)),
            "expected Authentication::Ok"
        );
        assert_eq!(consumed, 9); // 1 tag + 4 len + 4 auth type
    }

    #[test]
    fn test_decode_rejects_oversized_message() {
        // Only the header is supplied: the length check must fire before the
        // decoder ever looks for the (absent) body.
        let oversized_len = super::MAX_MESSAGE_LENGTH as i32 + 1;
        let mut data = BytesMut::with_capacity(5);
        data.extend_from_slice(&[b'D']);
        data.extend_from_slice(&oversized_len.to_be_bytes());

        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_decode_ready_for_query() {
        // 'Z' with status 'I' (idle).
        let mut data = frame(b'Z', b"I");

        let (msg, consumed) = decode_message(&mut data).unwrap();
        if let BackendMessage::ReadyForQuery { status } = msg {
            assert_eq!(status, b'I');
        } else {
            panic!("expected ReadyForQuery");
        }
        assert_eq!(consumed, 6); // 1 tag + 4 len + 1 status
    }
}