Skip to main content

fraiseql_wire/protocol/
decode.rs

1//! Protocol message decoding
2
3use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
8/// Decode a backend message from BytesMut without cloning
9///
10/// This version decodes in-place from a mutable BytesMut buffer and returns
11/// the number of bytes consumed. The caller must advance the buffer after calling this.
12///
13/// # Returns
14/// `Ok((msg, consumed))` - Message and number of bytes consumed
15/// `Err(e)` - IO error if message is incomplete or invalid
16///
17/// # Performance
18/// This version avoids the expensive `buf.clone().freeze()` call by working directly
19/// with references, reducing allocations and copies in the hot path.
20pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
21    if data.len() < 5 {
22        return Err(io::Error::new(
23            io::ErrorKind::UnexpectedEof,
24            "incomplete message header",
25        ));
26    }
27
28    let tag = data[0];
29    let len_i32 = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
30
31    // PostgreSQL message length includes the 4 length bytes but not the tag byte.
32    // Minimum valid length is 4 (just the length field itself).
33    if len_i32 < 4 {
34        return Err(io::Error::new(
35            io::ErrorKind::InvalidData,
36            "message length too small",
37        ));
38    }
39
40    let len = len_i32 as usize;
41
42    if data.len() < len + 1 {
43        return Err(io::Error::new(
44            io::ErrorKind::UnexpectedEof,
45            "incomplete message body",
46        ));
47    }
48
49    // Create a temporary slice starting after the tag and length
50    let msg_start = 5;
51    let msg_end = len + 1;
52    let msg_data = &data[msg_start..msg_end];
53
54    let msg = match tag {
55        tags::AUTHENTICATION => decode_authentication(msg_data)?,
56        tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
57        tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
58        tags::DATA_ROW => decode_data_row(msg_data)?,
59        tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
60        tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
61        tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
62        tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
63        tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
64        _ => {
65            return Err(io::Error::new(
66                io::ErrorKind::InvalidData,
67                format!("unknown message tag: {}", tag),
68            ))
69        }
70    };
71
72    Ok((msg, len + 1))
73}
74
75fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
76    if data.len() < 4 {
77        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
78    }
79    let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
80
81    let auth_msg = match auth_type {
82        auth::OK => AuthenticationMessage::Ok,
83        auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
84        auth::MD5_PASSWORD => {
85            if data.len() < 8 {
86                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
87            }
88            let mut salt = [0u8; 4];
89            salt.copy_from_slice(&data[4..8]);
90            AuthenticationMessage::Md5Password { salt }
91        }
92        auth::SASL => {
93            // SASL: read mechanism list (null-terminated strings)
94            let mut mechanisms = Vec::new();
95            let remaining = &data[4..];
96            let mut offset = 0;
97            loop {
98                if offset >= remaining.len() {
99                    break;
100                }
101                match remaining[offset..].iter().position(|&b| b == 0) {
102                    Some(end) => {
103                        let mechanism =
104                            String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
105                        if mechanism.is_empty() {
106                            break;
107                        }
108                        mechanisms.push(mechanism);
109                        offset += end + 1;
110                    }
111                    None => break,
112                }
113            }
114            AuthenticationMessage::Sasl { mechanisms }
115        }
116        auth::SASL_CONTINUE => {
117            // SASL continue: read remaining data as bytes
118            let data_vec = data[4..].to_vec();
119            AuthenticationMessage::SaslContinue { data: data_vec }
120        }
121        auth::SASL_FINAL => {
122            // SASL final: read remaining data as bytes
123            let data_vec = data[4..].to_vec();
124            AuthenticationMessage::SaslFinal { data: data_vec }
125        }
126        _ => {
127            return Err(io::Error::new(
128                io::ErrorKind::Unsupported,
129                format!("unsupported auth type: {}", auth_type),
130            ))
131        }
132    };
133
134    Ok(BackendMessage::Authentication(auth_msg))
135}
136
137fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
138    if data.len() < 8 {
139        return Err(io::Error::new(
140            io::ErrorKind::UnexpectedEof,
141            "backend key data",
142        ));
143    }
144    let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
145    let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
146    Ok(BackendMessage::BackendKeyData {
147        process_id,
148        secret_key,
149    })
150}
151
152fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
153    let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
154        io::Error::new(
155            io::ErrorKind::InvalidData,
156            "missing null terminator in string",
157        )
158    })?;
159    let tag = String::from_utf8_lossy(&data[..end]).to_string();
160    Ok(BackendMessage::CommandComplete(tag))
161}
162
163fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
164    if data.len() < 2 {
165        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
166    }
167    let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
168    if field_count_i16 < 0 {
169        return Err(io::Error::new(
170            io::ErrorKind::InvalidData,
171            "negative field count",
172        ));
173    }
174    let field_count = field_count_i16 as usize;
175    let mut fields = Vec::with_capacity(field_count);
176    let mut offset = 2;
177
178    for _ in 0..field_count {
179        if offset + 4 > data.len() {
180            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
181        }
182        let field_len = i32::from_be_bytes([
183            data[offset],
184            data[offset + 1],
185            data[offset + 2],
186            data[offset + 3],
187        ]);
188        offset += 4;
189
190        let field = if field_len == -1 {
191            None
192        } else if field_len < 0 {
193            return Err(io::Error::new(
194                io::ErrorKind::InvalidData,
195                "negative field length",
196            ));
197        } else {
198            let len = field_len as usize;
199            if offset + len > data.len() {
200                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
201            }
202            let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
203            offset += len;
204            Some(field_bytes)
205        };
206        fields.push(field);
207    }
208
209    Ok(BackendMessage::DataRow(fields))
210}
211
212fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
213    let fields = decode_error_fields(data)?;
214    Ok(BackendMessage::ErrorResponse(fields))
215}
216
217fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
218    let fields = decode_error_fields(data)?;
219    Ok(BackendMessage::NoticeResponse(fields))
220}
221
222fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
223    let mut fields = ErrorFields::default();
224    let mut offset = 0;
225
226    loop {
227        if offset >= data.len() {
228            break;
229        }
230        let field_type = data[offset];
231        offset += 1;
232        if field_type == 0 {
233            break;
234        }
235
236        let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
237            io::Error::new(
238                io::ErrorKind::InvalidData,
239                "missing null terminator in error field",
240            )
241        })?;
242        let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
243        offset += end + 1;
244
245        match field_type {
246            b'S' => fields.severity = Some(value),
247            b'C' => fields.code = Some(value),
248            b'M' => fields.message = Some(value),
249            b'D' => fields.detail = Some(value),
250            b'H' => fields.hint = Some(value),
251            b'P' => fields.position = Some(value),
252            _ => {} // Ignore unknown fields
253        }
254    }
255
256    Ok(fields)
257}
258
259fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
260    let mut offset = 0;
261
262    let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
263        io::Error::new(
264            io::ErrorKind::InvalidData,
265            "missing null terminator in parameter name",
266        )
267    })?;
268    let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
269    offset += name_end + 1;
270
271    if offset >= data.len() {
272        return Err(io::Error::new(
273            io::ErrorKind::UnexpectedEof,
274            "parameter value",
275        ));
276    }
277    let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
278        io::Error::new(
279            io::ErrorKind::InvalidData,
280            "missing null terminator in parameter value",
281        )
282    })?;
283    let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
284
285    Ok(BackendMessage::ParameterStatus { name, value })
286}
287
288fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
289    if data.is_empty() {
290        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
291    }
292    let status = data[0];
293    Ok(BackendMessage::ReadyForQuery { status })
294}
295
296fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
297    if data.len() < 2 {
298        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
299    }
300    let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
301    if field_count_i16 < 0 {
302        return Err(io::Error::new(
303            io::ErrorKind::InvalidData,
304            "negative field count",
305        ));
306    }
307    let field_count = field_count_i16 as usize;
308    let mut fields = Vec::with_capacity(field_count);
309    let mut offset = 2;
310
311    for _ in 0..field_count {
312        // Read name (null-terminated string)
313        let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
314            io::Error::new(
315                io::ErrorKind::InvalidData,
316                "missing null terminator in field name",
317            )
318        })?;
319        let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
320        offset += name_end + 1;
321
322        // Read field descriptor (26 bytes: 4+2+4+2+4+2)
323        if offset + 18 > data.len() {
324            return Err(io::Error::new(
325                io::ErrorKind::UnexpectedEof,
326                "field descriptor",
327            ));
328        }
329        let table_oid = i32::from_be_bytes([
330            data[offset],
331            data[offset + 1],
332            data[offset + 2],
333            data[offset + 3],
334        ]);
335        offset += 4;
336        let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
337        offset += 2;
338        let type_oid = i32::from_be_bytes([
339            data[offset],
340            data[offset + 1],
341            data[offset + 2],
342            data[offset + 3],
343        ]) as u32;
344        offset += 4;
345        let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
346        offset += 2;
347        let type_modifier = i32::from_be_bytes([
348            data[offset],
349            data[offset + 1],
350            data[offset + 2],
351            data[offset + 3],
352        ]);
353        offset += 4;
354        let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
355        offset += 2;
356
357        fields.push(FieldDescription {
358            name,
359            table_oid,
360            column_attr,
361            type_oid,
362            type_size,
363            type_modifier,
364            format_code,
365        });
366    }
367
368    Ok(BackendMessage::RowDescription(fields))
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a complete frame: tag byte, big-endian length (counting itself), body.
    fn frame(tag: u8, body: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(body.len() + 5);
        buf.extend_from_slice(&[tag]);
        buf.extend_from_slice(&((body.len() + 4) as i32).to_be_bytes());
        buf.extend_from_slice(body);
        buf
    }

    /// Decode `buf`, expecting failure, and return the error.
    fn decode_err(mut buf: BytesMut) -> io::Error {
        match decode_message(&mut buf) {
            Ok(_) => panic!("expected a decode error"),
            Err(e) => e,
        }
    }

    #[test]
    fn test_decode_authentication_ok() {
        let mut data = BytesMut::from(
            &[
                b'R', // Authentication
                0, 0, 0, 8, // Length = 8
                0, 0, 0, 0, // Auth OK
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Ok) => {}
            _ => panic!("expected Authentication::Ok"),
        }
        assert_eq!(consumed, 9); // 1 tag + 4 len + 4 auth type
    }

    #[test]
    fn test_decode_ready_for_query() {
        let mut data = BytesMut::from(
            &[
                b'Z', // ReadyForQuery
                0, 0, 0, 5, // Length = 5
                b'I', // Idle
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ReadyForQuery { status } => assert_eq!(status, b'I'),
            _ => panic!("expected ReadyForQuery"),
        }
        assert_eq!(consumed, 6); // 1 tag + 4 len + 1 status
    }

    #[test]
    fn test_consumes_only_first_frame_and_leaves_buffer_untouched() {
        let mut data = frame(tags::READY_FOR_QUERY, b"I");
        data.extend_from_slice(&[b'Z', 0, 0]); // start of a second, partial frame
        let (_, consumed) = decode_message(&mut data).unwrap();
        assert_eq!(consumed, 6);
        assert_eq!(data.len(), 9, "decode_message must not advance the buffer");
    }

    #[test]
    fn test_incomplete_header_is_eof() {
        let err = decode_err(BytesMut::from(&[b'Z', 0, 0][..]));
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_length_below_four_is_invalid() {
        let err = decode_err(BytesMut::from(&[b'Z', 0, 0, 0, 3][..]));
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_incomplete_body_is_eof() {
        // Length 5 announces a 1-byte body that has not arrived yet.
        let err = decode_err(BytesMut::from(&[b'Z', 0, 0, 0, 5][..]));
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_unknown_tag_is_invalid() {
        let err = decode_err(frame(b'~', b""));
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_decode_data_row_with_null() {
        let body: &[u8] = &[
            0, 2, // two columns
            0, 0, 0, 2, b'a', b'b', // "ab"
            0xFF, 0xFF, 0xFF, 0xFF, // NULL (-1)
        ];
        let (msg, consumed) = decode_message(&mut frame(tags::DATA_ROW, body)).unwrap();
        assert_eq!(consumed, body.len() + 5);
        match msg {
            BackendMessage::DataRow(fields) => {
                assert_eq!(fields.len(), 2);
                assert_eq!(fields[0].as_deref(), Some(&b"ab"[..]));
                assert!(fields[1].is_none());
            }
            _ => panic!("expected DataRow"),
        }
    }

    #[test]
    fn test_data_row_huge_count_tiny_body_errors() {
        // Claims 32767 columns but carries none.
        decode_err(frame(tags::DATA_ROW, &[0x7F, 0xFF]));
    }

    #[test]
    fn test_decode_backend_key_data() {
        let mut body = Vec::new();
        body.extend_from_slice(&1234i32.to_be_bytes());
        body.extend_from_slice(&(-5678i32).to_be_bytes());
        let (msg, _) = decode_message(&mut frame(tags::BACKEND_KEY_DATA, &body)).unwrap();
        match msg {
            BackendMessage::BackendKeyData {
                process_id,
                secret_key,
            } => {
                assert_eq!(process_id, 1234);
                assert_eq!(secret_key, -5678);
            }
            _ => panic!("expected BackendKeyData"),
        }
    }

    #[test]
    fn test_decode_md5_salt() {
        let mut body = auth::MD5_PASSWORD.to_be_bytes().to_vec();
        body.extend_from_slice(&[1, 2, 3, 4]);
        let (msg, _) = decode_message(&mut frame(tags::AUTHENTICATION, &body)).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Md5Password { salt }) => {
                assert_eq!(salt, [1, 2, 3, 4]);
            }
            _ => panic!("expected Md5Password"),
        }
    }

    #[test]
    fn test_decode_sasl_mechanisms() {
        let mut body = auth::SASL.to_be_bytes().to_vec();
        body.extend_from_slice(b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0");
        let (msg, _) = decode_message(&mut frame(tags::AUTHENTICATION, &body)).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Sasl { mechanisms }) => {
                assert_eq!(mechanisms, ["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
            }
            _ => panic!("expected Sasl"),
        }
    }

    #[test]
    fn test_decode_parameter_status() {
        let (msg, _) =
            decode_message(&mut frame(tags::PARAMETER_STATUS, b"client_encoding\0UTF8\0"))
                .unwrap();
        match msg {
            BackendMessage::ParameterStatus { name, value } => {
                assert_eq!(name, "client_encoding");
                assert_eq!(value, "UTF8");
            }
            _ => panic!("expected ParameterStatus"),
        }
    }

    #[test]
    fn test_decode_error_response() {
        let body = b"SERROR\0C42P01\0Mrelation missing\0\0";
        let (msg, _) = decode_message(&mut frame(tags::ERROR_RESPONSE, body)).unwrap();
        match msg {
            BackendMessage::ErrorResponse(fields) => {
                assert_eq!(fields.severity.as_deref(), Some("ERROR"));
                assert_eq!(fields.code.as_deref(), Some("42P01"));
                assert_eq!(fields.message.as_deref(), Some("relation missing"));
                assert!(fields.detail.is_none());
            }
            _ => panic!("expected ErrorResponse"),
        }
    }

    #[test]
    fn test_decode_command_complete() {
        let (msg, _) = decode_message(&mut frame(tags::COMMAND_COMPLETE, b"SELECT 1\0")).unwrap();
        match msg {
            BackendMessage::CommandComplete(tag) => assert_eq!(tag, "SELECT 1"),
            _ => panic!("expected CommandComplete"),
        }
    }

    #[test]
    fn test_decode_row_description() {
        let mut body = vec![0, 1];
        body.extend_from_slice(b"id\0");
        body.extend_from_slice(&16384i32.to_be_bytes()); // table_oid
        body.extend_from_slice(&1i16.to_be_bytes()); // column_attr
        body.extend_from_slice(&23u32.to_be_bytes()); // type_oid
        body.extend_from_slice(&4i16.to_be_bytes()); // type_size
        body.extend_from_slice(&(-1i32).to_be_bytes()); // type_modifier
        body.extend_from_slice(&0i16.to_be_bytes()); // format_code

        let (msg, _) = decode_message(&mut frame(tags::ROW_DESCRIPTION, &body)).unwrap();
        match msg {
            BackendMessage::RowDescription(fields) => {
                assert_eq!(fields.len(), 1);
                let f = &fields[0];
                assert_eq!(f.name, "id");
                assert_eq!(f.table_oid, 16384);
                assert_eq!(f.column_attr, 1);
                assert_eq!(f.type_oid, 23);
                assert_eq!(f.type_size, 4);
                assert_eq!(f.type_modifier, -1);
                assert_eq!(f.format_code, 0);
            }
            _ => panic!("expected RowDescription"),
        }
    }

    #[test]
    fn test_row_description_short_descriptor_errors() {
        let mut body = vec![0, 1];
        body.extend_from_slice(b"id\0");
        body.extend_from_slice(&[0; 10]); // descriptor needs 18 bytes
        decode_err(frame(tags::ROW_DESCRIPTION, &body));
    }
}