Skip to main content

fraiseql_wire/protocol/
decode.rs

1//! Protocol message decoding
2
3use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
8/// Decode a backend message from BytesMut without cloning
9///
10/// This version decodes in-place from a mutable BytesMut buffer and returns
11/// the number of bytes consumed. The caller must advance the buffer after calling this.
12///
13/// # Returns
14/// `Ok((msg, consumed))` - Message and number of bytes consumed
15/// `Err(e)` - IO error if message is incomplete or invalid
16///
17/// # Performance
18/// This version avoids the expensive `buf.clone().freeze()` call by working directly
19/// with references, reducing allocations and copies in the hot path.
20pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
21    if data.len() < 5 {
22        return Err(io::Error::new(
23            io::ErrorKind::UnexpectedEof,
24            "incomplete message header",
25        ));
26    }
27
28    let tag = data[0];
29    let len = i32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
30
31    if data.len() < len + 1 {
32        return Err(io::Error::new(
33            io::ErrorKind::UnexpectedEof,
34            "incomplete message body",
35        ));
36    }
37
38    // Create a temporary slice starting after the tag and length
39    let msg_start = 5;
40    let msg_end = len + 1;
41    let msg_data = &data[msg_start..msg_end];
42
43    let msg = match tag {
44        tags::AUTHENTICATION => decode_authentication(msg_data)?,
45        tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
46        tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
47        tags::DATA_ROW => decode_data_row(msg_data)?,
48        tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
49        tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
50        tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
51        tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
52        tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
53        _ => {
54            return Err(io::Error::new(
55                io::ErrorKind::InvalidData,
56                format!("unknown message tag: {}", tag),
57            ))
58        }
59    };
60
61    Ok((msg, len + 1))
62}
63
64fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
65    if data.len() < 4 {
66        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
67    }
68    let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
69
70    let auth_msg = match auth_type {
71        auth::OK => AuthenticationMessage::Ok,
72        auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
73        auth::MD5_PASSWORD => {
74            if data.len() < 8 {
75                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
76            }
77            let mut salt = [0u8; 4];
78            salt.copy_from_slice(&data[4..8]);
79            AuthenticationMessage::Md5Password { salt }
80        }
81        auth::SASL => {
82            // SASL: read mechanism list (null-terminated strings)
83            let mut mechanisms = Vec::new();
84            let remaining = &data[4..];
85            let mut offset = 0;
86            loop {
87                if offset >= remaining.len() {
88                    break;
89                }
90                match remaining[offset..].iter().position(|&b| b == 0) {
91                    Some(end) => {
92                        let mechanism =
93                            String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
94                        if mechanism.is_empty() {
95                            break;
96                        }
97                        mechanisms.push(mechanism);
98                        offset += end + 1;
99                    }
100                    None => break,
101                }
102            }
103            AuthenticationMessage::Sasl { mechanisms }
104        }
105        auth::SASL_CONTINUE => {
106            // SASL continue: read remaining data as bytes
107            let data_vec = data[4..].to_vec();
108            AuthenticationMessage::SaslContinue { data: data_vec }
109        }
110        auth::SASL_FINAL => {
111            // SASL final: read remaining data as bytes
112            let data_vec = data[4..].to_vec();
113            AuthenticationMessage::SaslFinal { data: data_vec }
114        }
115        _ => {
116            return Err(io::Error::new(
117                io::ErrorKind::Unsupported,
118                format!("unsupported auth type: {}", auth_type),
119            ))
120        }
121    };
122
123    Ok(BackendMessage::Authentication(auth_msg))
124}
125
126fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
127    if data.len() < 8 {
128        return Err(io::Error::new(
129            io::ErrorKind::UnexpectedEof,
130            "backend key data",
131        ));
132    }
133    let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
134    let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
135    Ok(BackendMessage::BackendKeyData {
136        process_id,
137        secret_key,
138    })
139}
140
141fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
142    let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
143        io::Error::new(
144            io::ErrorKind::InvalidData,
145            "missing null terminator in string",
146        )
147    })?;
148    let tag = String::from_utf8_lossy(&data[..end]).to_string();
149    Ok(BackendMessage::CommandComplete(tag))
150}
151
152fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
153    if data.len() < 2 {
154        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
155    }
156    let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
157    let mut fields = Vec::with_capacity(field_count);
158    let mut offset = 2;
159
160    for _ in 0..field_count {
161        if offset + 4 > data.len() {
162            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
163        }
164        let field_len = i32::from_be_bytes([
165            data[offset],
166            data[offset + 1],
167            data[offset + 2],
168            data[offset + 3],
169        ]);
170        offset += 4;
171
172        let field = if field_len == -1 {
173            None
174        } else {
175            let len = field_len as usize;
176            if offset + len > data.len() {
177                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
178            }
179            let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
180            offset += len;
181            Some(field_bytes)
182        };
183        fields.push(field);
184    }
185
186    Ok(BackendMessage::DataRow(fields))
187}
188
189fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
190    let fields = decode_error_fields(data)?;
191    Ok(BackendMessage::ErrorResponse(fields))
192}
193
194fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
195    let fields = decode_error_fields(data)?;
196    Ok(BackendMessage::NoticeResponse(fields))
197}
198
199fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
200    let mut fields = ErrorFields::default();
201    let mut offset = 0;
202
203    loop {
204        if offset >= data.len() {
205            break;
206        }
207        let field_type = data[offset];
208        offset += 1;
209        if field_type == 0 {
210            break;
211        }
212
213        let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
214            io::Error::new(
215                io::ErrorKind::InvalidData,
216                "missing null terminator in error field",
217            )
218        })?;
219        let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
220        offset += end + 1;
221
222        match field_type {
223            b'S' => fields.severity = Some(value),
224            b'C' => fields.code = Some(value),
225            b'M' => fields.message = Some(value),
226            b'D' => fields.detail = Some(value),
227            b'H' => fields.hint = Some(value),
228            b'P' => fields.position = Some(value),
229            _ => {} // Ignore unknown fields
230        }
231    }
232
233    Ok(fields)
234}
235
236fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
237    let mut offset = 0;
238
239    let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
240        io::Error::new(
241            io::ErrorKind::InvalidData,
242            "missing null terminator in parameter name",
243        )
244    })?;
245    let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
246    offset += name_end + 1;
247
248    if offset >= data.len() {
249        return Err(io::Error::new(
250            io::ErrorKind::UnexpectedEof,
251            "parameter value",
252        ));
253    }
254    let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
255        io::Error::new(
256            io::ErrorKind::InvalidData,
257            "missing null terminator in parameter value",
258        )
259    })?;
260    let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
261
262    Ok(BackendMessage::ParameterStatus { name, value })
263}
264
265fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
266    if data.is_empty() {
267        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
268    }
269    let status = data[0];
270    Ok(BackendMessage::ReadyForQuery { status })
271}
272
273fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
274    if data.len() < 2 {
275        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
276    }
277    let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
278    let mut fields = Vec::with_capacity(field_count);
279    let mut offset = 2;
280
281    for _ in 0..field_count {
282        // Read name (null-terminated string)
283        let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
284            io::Error::new(
285                io::ErrorKind::InvalidData,
286                "missing null terminator in field name",
287            )
288        })?;
289        let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
290        offset += name_end + 1;
291
292        // Read field descriptor (26 bytes: 4+2+4+2+4+2)
293        if offset + 18 > data.len() {
294            return Err(io::Error::new(
295                io::ErrorKind::UnexpectedEof,
296                "field descriptor",
297            ));
298        }
299        let table_oid = i32::from_be_bytes([
300            data[offset],
301            data[offset + 1],
302            data[offset + 2],
303            data[offset + 3],
304        ]);
305        offset += 4;
306        let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
307        offset += 2;
308        let type_oid = i32::from_be_bytes([
309            data[offset],
310            data[offset + 1],
311            data[offset + 2],
312            data[offset + 3],
313        ]) as u32;
314        offset += 4;
315        let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
316        offset += 2;
317        let type_modifier = i32::from_be_bytes([
318            data[offset],
319            data[offset + 1],
320            data[offset + 2],
321            data[offset + 3],
322        ]);
323        offset += 4;
324        let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
325        offset += 2;
326
327        fields.push(FieldDescription {
328            name,
329            table_oid,
330            column_attr,
331            type_oid,
332            type_size,
333            type_modifier,
334            format_code,
335        });
336    }
337
338    Ok(BackendMessage::RowDescription(fields))
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;

    /// Frame `payload` as a backend message. The frame is the tag byte, then a big-endian
    /// length counting its own four bytes plus the payload, then the payload.
    fn frame(tag: u8, payload: &[u8]) -> BytesMut {
        let mut raw = Vec::with_capacity(5 + payload.len());
        raw.push(tag);
        raw.extend_from_slice(&((payload.len() + 4) as i32).to_be_bytes());
        raw.extend_from_slice(payload);
        BytesMut::from(&raw[..])
    }

    /// Decode `data` and return the error kind, panicking if decoding succeeds.
    fn error_kind(mut data: BytesMut) -> ErrorKind {
        match decode_message(&mut data) {
            Ok(_) => panic!("expected decode error"),
            Err(e) => e.kind(),
        }
    }

    #[test]
    fn test_decode_authentication_ok() {
        let mut data = BytesMut::from(
            &[
                b'R', // Authentication
                0, 0, 0, 8, // Length = 8
                0, 0, 0, 0, // Auth OK
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Ok) => {}
            _ => panic!("expected Authentication::Ok"),
        }
        assert_eq!(consumed, 9); // 1 tag + 4 len + 4 auth type
    }

    #[test]
    fn test_decode_ready_for_query() {
        let mut data = BytesMut::from(
            &[
                b'Z', // ReadyForQuery
                0, 0, 0, 5, // Length = 5
                b'I', // Idle
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ReadyForQuery { status } => assert_eq!(status, b'I'),
            _ => panic!("expected ReadyForQuery"),
        }
        assert_eq!(consumed, 6); // 1 tag + 4 len + 1 status
    }

    #[test]
    fn test_incomplete_header_is_unexpected_eof() {
        let data = BytesMut::from(&[b'Z', 0, 0][..]);
        assert_eq!(error_kind(data), ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_incomplete_body_is_unexpected_eof() {
        // Header announces a 1-byte payload that has not arrived yet.
        let data = BytesMut::from(&[b'Z', 0, 0, 0, 5][..]);
        assert_eq!(error_kind(data), ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_unknown_tag_is_invalid_data() {
        assert_eq!(error_kind(frame(b'!', &[])), ErrorKind::InvalidData);
    }

    #[test]
    fn test_decode_parameter_status() {
        let mut data = frame(b'S', b"client_encoding\0UTF8\0");
        let expected_len = data.len();

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ParameterStatus { name, value } => {
                assert_eq!(name, "client_encoding");
                assert_eq!(value, "UTF8");
            }
            _ => panic!("expected ParameterStatus"),
        }
        assert_eq!(consumed, expected_len);
    }

    #[test]
    fn test_decode_data_row_with_null_consumes_only_first_frame() {
        // One column whose length is -1 (NULL), followed by a second, untouched frame.
        let mut data = frame(b'D', &[0, 1, 0xFF, 0xFF, 0xFF, 0xFF]);
        let first_len = data.len();
        data.extend_from_slice(&frame(b'Z', b"I"));

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::DataRow(fields) => {
                assert_eq!(fields.len(), 1);
                assert!(fields[0].is_none());
            }
            _ => panic!("expected DataRow"),
        }
        assert_eq!(consumed, first_len);
    }
}