1use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
/// Largest length field (1 GiB, i.e. 2^30 bytes) accepted for one backend message.
///
/// `decode_message` checks this bound before checking whether the buffer holds
/// the whole body, so an oversized length prefix is reported as `InvalidData`
/// rather than being treated as an incomplete frame.
const MAX_MESSAGE_LENGTH: usize = 1 << 30;
13
14pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
27 if data.len() < 5 {
28 return Err(io::Error::new(
29 io::ErrorKind::UnexpectedEof,
30 "incomplete message header",
31 ));
32 }
33
34 let tag = data[0];
35 let len = i32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
36
37 if len > MAX_MESSAGE_LENGTH {
38 return Err(io::Error::new(
39 io::ErrorKind::InvalidData,
40 format!(
41 "message length {} exceeds maximum allowed {}",
42 len, MAX_MESSAGE_LENGTH
43 ),
44 ));
45 }
46
47 if data.len() < len + 1 {
48 return Err(io::Error::new(
49 io::ErrorKind::UnexpectedEof,
50 "incomplete message body",
51 ));
52 }
53
54 let msg_start = 5;
56 let msg_end = len + 1;
57 let msg_data = &data[msg_start..msg_end];
58
59 let msg = match tag {
60 tags::AUTHENTICATION => decode_authentication(msg_data)?,
61 tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
62 tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
63 tags::DATA_ROW => decode_data_row(msg_data)?,
64 tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
65 tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
66 tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
67 tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
68 tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
69 _ => {
70 return Err(io::Error::new(
71 io::ErrorKind::InvalidData,
72 format!("unknown message tag: {}", tag),
73 ))
74 }
75 };
76
77 Ok((msg, len + 1))
78}
79
80fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
81 if data.len() < 4 {
82 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
83 }
84 let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
85
86 let auth_msg = match auth_type {
87 auth::OK => AuthenticationMessage::Ok,
88 auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
89 auth::MD5_PASSWORD => {
90 if data.len() < 8 {
91 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
92 }
93 let mut salt = [0u8; 4];
94 salt.copy_from_slice(&data[4..8]);
95 AuthenticationMessage::Md5Password { salt }
96 }
97 auth::SASL => {
98 let mut mechanisms = Vec::new();
100 let remaining = &data[4..];
101 let mut offset = 0;
102 loop {
103 if offset >= remaining.len() {
104 break;
105 }
106 match remaining[offset..].iter().position(|&b| b == 0) {
107 Some(end) => {
108 let mechanism =
109 String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
110 if mechanism.is_empty() {
111 break;
112 }
113 mechanisms.push(mechanism);
114 offset += end + 1;
115 }
116 None => break,
117 }
118 }
119 AuthenticationMessage::Sasl { mechanisms }
120 }
121 auth::SASL_CONTINUE => {
122 let data_vec = data[4..].to_vec();
124 AuthenticationMessage::SaslContinue { data: data_vec }
125 }
126 auth::SASL_FINAL => {
127 let data_vec = data[4..].to_vec();
129 AuthenticationMessage::SaslFinal { data: data_vec }
130 }
131 _ => {
132 return Err(io::Error::new(
133 io::ErrorKind::Unsupported,
134 format!("unsupported auth type: {}", auth_type),
135 ))
136 }
137 };
138
139 Ok(BackendMessage::Authentication(auth_msg))
140}
141
142fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
143 if data.len() < 8 {
144 return Err(io::Error::new(
145 io::ErrorKind::UnexpectedEof,
146 "backend key data",
147 ));
148 }
149 let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
150 let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
151 Ok(BackendMessage::BackendKeyData {
152 process_id,
153 secret_key,
154 })
155}
156
157fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
158 let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
159 io::Error::new(
160 io::ErrorKind::InvalidData,
161 "missing null terminator in string",
162 )
163 })?;
164 let tag = String::from_utf8_lossy(&data[..end]).to_string();
165 Ok(BackendMessage::CommandComplete(tag))
166}
167
168fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
169 if data.len() < 2 {
170 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
171 }
172 let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
173 let mut fields = Vec::with_capacity(field_count);
174 let mut offset = 2;
175
176 for _ in 0..field_count {
177 if offset + 4 > data.len() {
178 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
179 }
180 let field_len = i32::from_be_bytes([
181 data[offset],
182 data[offset + 1],
183 data[offset + 2],
184 data[offset + 3],
185 ]);
186 offset += 4;
187
188 let field = if field_len == -1 {
189 None
190 } else {
191 let len = field_len as usize;
192 if offset + len > data.len() {
193 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
194 }
195 let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
196 offset += len;
197 Some(field_bytes)
198 };
199 fields.push(field);
200 }
201
202 Ok(BackendMessage::DataRow(fields))
203}
204
205fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
206 let fields = decode_error_fields(data)?;
207 Ok(BackendMessage::ErrorResponse(fields))
208}
209
210fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
211 let fields = decode_error_fields(data)?;
212 Ok(BackendMessage::NoticeResponse(fields))
213}
214
215fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
216 let mut fields = ErrorFields::default();
217 let mut offset = 0;
218
219 loop {
220 if offset >= data.len() {
221 break;
222 }
223 let field_type = data[offset];
224 offset += 1;
225 if field_type == 0 {
226 break;
227 }
228
229 let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
230 io::Error::new(
231 io::ErrorKind::InvalidData,
232 "missing null terminator in error field",
233 )
234 })?;
235 let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
236 offset += end + 1;
237
238 match field_type {
239 b'S' => fields.severity = Some(value),
240 b'C' => fields.code = Some(value),
241 b'M' => fields.message = Some(value),
242 b'D' => fields.detail = Some(value),
243 b'H' => fields.hint = Some(value),
244 b'P' => fields.position = Some(value),
245 _ => {} }
247 }
248
249 Ok(fields)
250}
251
252fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
253 let mut offset = 0;
254
255 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
256 io::Error::new(
257 io::ErrorKind::InvalidData,
258 "missing null terminator in parameter name",
259 )
260 })?;
261 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
262 offset += name_end + 1;
263
264 if offset >= data.len() {
265 return Err(io::Error::new(
266 io::ErrorKind::UnexpectedEof,
267 "parameter value",
268 ));
269 }
270 let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
271 io::Error::new(
272 io::ErrorKind::InvalidData,
273 "missing null terminator in parameter value",
274 )
275 })?;
276 let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
277
278 Ok(BackendMessage::ParameterStatus { name, value })
279}
280
281fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
282 if data.is_empty() {
283 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
284 }
285 let status = data[0];
286 Ok(BackendMessage::ReadyForQuery { status })
287}
288
289fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
290 if data.len() < 2 {
291 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
292 }
293 let field_count = i16::from_be_bytes([data[0], data[1]]) as usize;
294 let mut fields = Vec::with_capacity(field_count);
295 let mut offset = 2;
296
297 for _ in 0..field_count {
298 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
300 io::Error::new(
301 io::ErrorKind::InvalidData,
302 "missing null terminator in field name",
303 )
304 })?;
305 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
306 offset += name_end + 1;
307
308 if offset + 18 > data.len() {
310 return Err(io::Error::new(
311 io::ErrorKind::UnexpectedEof,
312 "field descriptor",
313 ));
314 }
315 let table_oid = i32::from_be_bytes([
316 data[offset],
317 data[offset + 1],
318 data[offset + 2],
319 data[offset + 3],
320 ]);
321 offset += 4;
322 let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
323 offset += 2;
324 let type_oid = i32::from_be_bytes([
325 data[offset],
326 data[offset + 1],
327 data[offset + 2],
328 data[offset + 3],
329 ]) as u32;
330 offset += 4;
331 let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
332 offset += 2;
333 let type_modifier = i32::from_be_bytes([
334 data[offset],
335 data[offset + 1],
336 data[offset + 2],
337 data[offset + 3],
338 ]);
339 offset += 4;
340 let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
341 offset += 2;
342
343 fields.push(FieldDescription {
344 name,
345 table_oid,
346 column_attr,
347 type_oid,
348 type_size,
349 type_modifier,
350 format_code,
351 });
352 }
353
354 Ok(BackendMessage::RowDescription(fields))
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_authentication_ok() {
        // Tag 'R', length 8 (the length field plus a 4-byte auth type), auth type 0.
        let mut data = BytesMut::from(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0][..]);

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Ok) => {}
            _ => panic!("expected Authentication::Ok"),
        }
        // Tag byte + 8 bytes of length field and body.
        assert_eq!(consumed, 9);
    }

    #[test]
    fn test_decode_rejects_oversized_message() {
        let oversized_len = (super::MAX_MESSAGE_LENGTH as i32) + 1;
        let len_bytes = oversized_len.to_be_bytes();
        let mut data =
            BytesMut::from(&[b'D', len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]][..]);

        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_decode_ready_for_query() {
        // Tag 'Z', length 5 (the length field plus one status byte), status 'I'.
        let mut data = BytesMut::from(&[b'Z', 0, 0, 0, 5, b'I'][..]);

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ReadyForQuery { status } => assert_eq!(status, b'I'),
            _ => panic!("expected ReadyForQuery"),
        }
        assert_eq!(consumed, 6);
    }

    #[test]
    fn test_decode_reports_incomplete_frames() {
        // Fewer than the 5 header bytes.
        let mut header_only = BytesMut::from(&[b'Z', 0, 0][..]);
        assert_eq!(
            decode_message(&mut header_only).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );

        // Header announces a 1-byte body that has not arrived yet.
        let mut missing_body = BytesMut::from(&[b'Z', 0, 0, 0, 5][..]);
        assert_eq!(
            decode_message(&mut missing_body).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn test_decode_data_row_with_null() {
        // Two columns: "42", then a length of -1.
        let body = [0u8, 2, 0, 0, 0, 2, b'4', b'2', 0xFF, 0xFF, 0xFF, 0xFF];
        match decode_data_row(&body).unwrap() {
            BackendMessage::DataRow(fields) => {
                assert_eq!(fields.len(), 2);
                assert_eq!(fields[0].as_deref(), Some(&b"42"[..]));
                assert!(fields[1].is_none());
            }
            _ => panic!("expected DataRow"),
        }
    }

    #[test]
    fn test_decode_parameter_status() {
        match decode_parameter_status(b"client_encoding\0UTF8\0").unwrap() {
            BackendMessage::ParameterStatus { name, value } => {
                assert_eq!(name, "client_encoding");
                assert_eq!(value, "UTF8");
            }
            _ => panic!("expected ParameterStatus"),
        }
    }

    #[test]
    fn test_decode_error_response_fields() {
        match decode_error_response(b"SERROR\0C42P01\0Mrelation does not exist\0\0").unwrap() {
            BackendMessage::ErrorResponse(fields) => {
                assert_eq!(fields.severity.as_deref(), Some("ERROR"));
                assert_eq!(fields.code.as_deref(), Some("42P01"));
                assert_eq!(fields.message.as_deref(), Some("relation does not exist"));
                assert!(fields.detail.is_none());
            }
            _ => panic!("expected ErrorResponse"),
        }
    }
}