1use super::constants::{auth, tags};
4use super::message::{AuthenticationMessage, BackendMessage, ErrorFields, FieldDescription};
5use bytes::{Bytes, BytesMut};
6use std::io;
7
/// Upper bound on the column count accepted in `DataRow` and `RowDescription` bodies.
const MAX_FIELD_COUNT: usize = 2_048;

/// Upper bound, in bytes, on any single value inside an `ErrorResponse` or
/// `NoticeResponse` body.
const MAX_ERROR_FIELD_BYTES: usize = 1 << 16;

/// Upper bound on the number of SASL mechanism names kept from an authentication
/// request; names beyond this are dropped rather than causing an error.
const MAX_SASL_MECHANISMS: usize = 32;

/// Upper bound, in bytes, on the parameter name in a `ParameterStatus` body.
const MAX_PARAMETER_NAME_BYTES: usize = 256;
33const MAX_PARAMETER_VALUE_BYTES: usize = 64 * 1024; pub fn decode_message(data: &mut BytesMut) -> io::Result<(BackendMessage, usize)> {
57 if data.len() < 5 {
58 return Err(io::Error::new(
59 io::ErrorKind::UnexpectedEof,
60 "incomplete message header",
61 ));
62 }
63
64 let tag = data[0];
65 let len_i32 = i32::from_be_bytes([data[1], data[2], data[3], data[4]]);
66
67 if len_i32 < 4 {
70 return Err(io::Error::new(
71 io::ErrorKind::InvalidData,
72 "message length too small",
73 ));
74 }
75
76 let len = len_i32 as usize;
77
78 if data.len() < len + 1 {
79 return Err(io::Error::new(
80 io::ErrorKind::UnexpectedEof,
81 "incomplete message body",
82 ));
83 }
84
85 let msg_start = 5;
87 let msg_end = len + 1;
88 let msg_data = &data[msg_start..msg_end];
89
90 let msg = match tag {
91 tags::AUTHENTICATION => decode_authentication(msg_data)?,
92 tags::BACKEND_KEY_DATA => decode_backend_key_data(msg_data)?,
93 tags::COMMAND_COMPLETE => decode_command_complete(msg_data)?,
94 tags::DATA_ROW => decode_data_row(msg_data)?,
95 tags::ERROR_RESPONSE => decode_error_response(msg_data)?,
96 tags::NOTICE_RESPONSE => decode_notice_response(msg_data)?,
97 tags::PARAMETER_STATUS => decode_parameter_status(msg_data)?,
98 tags::READY_FOR_QUERY => decode_ready_for_query(msg_data)?,
99 tags::ROW_DESCRIPTION => decode_row_description(msg_data)?,
100 _ => {
101 return Err(io::Error::new(
102 io::ErrorKind::InvalidData,
103 format!("unknown message tag: {}", tag),
104 ))
105 }
106 };
107
108 Ok((msg, len + 1))
109}
110
111fn decode_authentication(data: &[u8]) -> io::Result<BackendMessage> {
112 if data.len() < 4 {
113 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "auth type"));
114 }
115 let auth_type = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
116
117 let auth_msg = match auth_type {
118 auth::OK => AuthenticationMessage::Ok,
119 auth::CLEARTEXT_PASSWORD => AuthenticationMessage::CleartextPassword,
120 auth::MD5_PASSWORD => {
121 if data.len() < 8 {
122 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "salt data"));
123 }
124 let mut salt = [0u8; 4];
125 salt.copy_from_slice(&data[4..8]);
126 AuthenticationMessage::Md5Password { salt }
127 }
128 auth::SASL => {
129 let mut mechanisms = Vec::new();
131 let remaining = &data[4..];
132 let mut offset = 0;
133 loop {
134 if offset >= remaining.len() {
135 break;
136 }
137 match remaining[offset..].iter().position(|&b| b == 0) {
138 Some(end) => {
139 let mechanism =
140 String::from_utf8_lossy(&remaining[offset..offset + end]).to_string();
141 if mechanism.is_empty() {
142 break;
143 }
144 if mechanisms.len() >= MAX_SASL_MECHANISMS {
145 break;
146 }
147 mechanisms.push(mechanism);
148 offset += end + 1;
149 }
150 None => break,
151 }
152 }
153 AuthenticationMessage::Sasl { mechanisms }
154 }
155 auth::SASL_CONTINUE => {
156 let data_vec = data[4..].to_vec();
158 AuthenticationMessage::SaslContinue { data: data_vec }
159 }
160 auth::SASL_FINAL => {
161 let data_vec = data[4..].to_vec();
163 AuthenticationMessage::SaslFinal { data: data_vec }
164 }
165 _ => {
166 return Err(io::Error::new(
167 io::ErrorKind::Unsupported,
168 format!("unsupported auth type: {}", auth_type),
169 ))
170 }
171 };
172
173 Ok(BackendMessage::Authentication(auth_msg))
174}
175
176fn decode_backend_key_data(data: &[u8]) -> io::Result<BackendMessage> {
177 if data.len() < 8 {
178 return Err(io::Error::new(
179 io::ErrorKind::UnexpectedEof,
180 "backend key data",
181 ));
182 }
183 let process_id = i32::from_be_bytes([data[0], data[1], data[2], data[3]]);
184 let secret_key = i32::from_be_bytes([data[4], data[5], data[6], data[7]]);
185 Ok(BackendMessage::BackendKeyData {
186 process_id,
187 secret_key,
188 })
189}
190
191fn decode_command_complete(data: &[u8]) -> io::Result<BackendMessage> {
192 let end = data.iter().position(|&b| b == 0).ok_or_else(|| {
193 io::Error::new(
194 io::ErrorKind::InvalidData,
195 "missing null terminator in string",
196 )
197 })?;
198 let tag = String::from_utf8_lossy(&data[..end]).to_string();
199 Ok(BackendMessage::CommandComplete(tag))
200}
201
202fn decode_data_row(data: &[u8]) -> io::Result<BackendMessage> {
203 if data.len() < 2 {
204 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
205 }
206 let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
207 if field_count_i16 < 0 {
208 return Err(io::Error::new(
209 io::ErrorKind::InvalidData,
210 "negative field count",
211 ));
212 }
213 let field_count = field_count_i16 as usize;
214 if field_count > MAX_FIELD_COUNT {
215 return Err(io::Error::new(
216 io::ErrorKind::InvalidData,
217 format!("DataRow field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
218 ));
219 }
220 let mut fields = Vec::with_capacity(field_count);
221 let mut offset = 2;
222
223 for _ in 0..field_count {
224 if offset + 4 > data.len() {
225 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field length"));
226 }
227 let field_len = i32::from_be_bytes([
228 data[offset],
229 data[offset + 1],
230 data[offset + 2],
231 data[offset + 3],
232 ]);
233 offset += 4;
234
235 let field = if field_len == -1 {
236 None
237 } else if field_len < 0 {
238 return Err(io::Error::new(
239 io::ErrorKind::InvalidData,
240 "negative field length",
241 ));
242 } else {
243 let len = field_len as usize;
244 if offset + len > data.len() {
245 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field data"));
246 }
247 let field_bytes = Bytes::copy_from_slice(&data[offset..offset + len]);
248 offset += len;
249 Some(field_bytes)
250 };
251 fields.push(field);
252 }
253
254 Ok(BackendMessage::DataRow(fields))
255}
256
257fn decode_error_response(data: &[u8]) -> io::Result<BackendMessage> {
258 let fields = decode_error_fields(data)?;
259 Ok(BackendMessage::ErrorResponse(fields))
260}
261
262fn decode_notice_response(data: &[u8]) -> io::Result<BackendMessage> {
263 let fields = decode_error_fields(data)?;
264 Ok(BackendMessage::NoticeResponse(fields))
265}
266
267fn decode_error_fields(data: &[u8]) -> io::Result<ErrorFields> {
268 let mut fields = ErrorFields::default();
269 let mut offset = 0;
270
271 loop {
272 if offset >= data.len() {
273 break;
274 }
275 let field_type = data[offset];
276 offset += 1;
277 if field_type == 0 {
278 break;
279 }
280
281 let end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
282 io::Error::new(
283 io::ErrorKind::InvalidData,
284 "missing null terminator in error field",
285 )
286 })?;
287 if end > MAX_ERROR_FIELD_BYTES {
288 return Err(io::Error::new(
289 io::ErrorKind::InvalidData,
290 format!("Error field too large ({end} bytes, max {MAX_ERROR_FIELD_BYTES})"),
291 ));
292 }
293 let value = String::from_utf8_lossy(&data[offset..offset + end]).to_string();
294 offset += end + 1;
295
296 match field_type {
297 b'S' => fields.severity = Some(value),
298 b'C' => fields.code = Some(value),
299 b'M' => fields.message = Some(value),
300 b'D' => fields.detail = Some(value),
301 b'H' => fields.hint = Some(value),
302 b'P' => fields.position = Some(value),
303 _ => {} }
305 }
306
307 Ok(fields)
308}
309
310fn decode_parameter_status(data: &[u8]) -> io::Result<BackendMessage> {
311 let mut offset = 0;
312
313 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
314 io::Error::new(
315 io::ErrorKind::InvalidData,
316 "missing null terminator in parameter name",
317 )
318 })?;
319 if name_end > MAX_PARAMETER_NAME_BYTES {
320 return Err(io::Error::new(
321 io::ErrorKind::InvalidData,
322 format!("Parameter name too long ({name_end} bytes, max {MAX_PARAMETER_NAME_BYTES})"),
323 ));
324 }
325 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
326 offset += name_end + 1;
327
328 if offset >= data.len() {
329 return Err(io::Error::new(
330 io::ErrorKind::UnexpectedEof,
331 "parameter value",
332 ));
333 }
334 let value_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
335 io::Error::new(
336 io::ErrorKind::InvalidData,
337 "missing null terminator in parameter value",
338 )
339 })?;
340 if value_end > MAX_PARAMETER_VALUE_BYTES {
341 return Err(io::Error::new(
342 io::ErrorKind::InvalidData,
343 format!(
344 "Parameter value too long ({value_end} bytes, max {MAX_PARAMETER_VALUE_BYTES})"
345 ),
346 ));
347 }
348 let value = String::from_utf8_lossy(&data[offset..offset + value_end]).to_string();
349
350 Ok(BackendMessage::ParameterStatus { name, value })
351}
352
353fn decode_ready_for_query(data: &[u8]) -> io::Result<BackendMessage> {
354 if data.is_empty() {
355 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "status byte"));
356 }
357 let status = data[0];
358 Ok(BackendMessage::ReadyForQuery { status })
359}
360
361fn decode_row_description(data: &[u8]) -> io::Result<BackendMessage> {
362 if data.len() < 2 {
363 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "field count"));
364 }
365 let field_count_i16 = i16::from_be_bytes([data[0], data[1]]);
366 if field_count_i16 < 0 {
367 return Err(io::Error::new(
368 io::ErrorKind::InvalidData,
369 "negative field count",
370 ));
371 }
372 let field_count = field_count_i16 as usize;
373 if field_count > MAX_FIELD_COUNT {
374 return Err(io::Error::new(
375 io::ErrorKind::InvalidData,
376 format!("RowDescription field count {field_count} exceeds maximum {MAX_FIELD_COUNT}"),
377 ));
378 }
379 let mut fields = Vec::with_capacity(field_count);
380 let mut offset = 2;
381
382 for _ in 0..field_count {
383 let name_end = data[offset..].iter().position(|&b| b == 0).ok_or_else(|| {
385 io::Error::new(
386 io::ErrorKind::InvalidData,
387 "missing null terminator in field name",
388 )
389 })?;
390 let name = String::from_utf8_lossy(&data[offset..offset + name_end]).to_string();
391 offset += name_end + 1;
392
393 if offset + 18 > data.len() {
395 return Err(io::Error::new(
396 io::ErrorKind::UnexpectedEof,
397 "field descriptor",
398 ));
399 }
400 let table_oid = i32::from_be_bytes([
401 data[offset],
402 data[offset + 1],
403 data[offset + 2],
404 data[offset + 3],
405 ]);
406 offset += 4;
407 let column_attr = i16::from_be_bytes([data[offset], data[offset + 1]]);
408 offset += 2;
409 let type_oid = i32::from_be_bytes([
410 data[offset],
411 data[offset + 1],
412 data[offset + 2],
413 data[offset + 3],
414 ]) as u32;
415 offset += 4;
416 let type_size = i16::from_be_bytes([data[offset], data[offset + 1]]);
417 offset += 2;
418 let type_modifier = i32::from_be_bytes([
419 data[offset],
420 data[offset + 1],
421 data[offset + 2],
422 data[offset + 3],
423 ]);
424 offset += 4;
425 let format_code = i16::from_be_bytes([data[offset], data[offset + 1]]);
426 offset += 2;
427
428 fields.push(FieldDescription {
429 name,
430 table_oid,
431 column_attr,
432 type_oid,
433 type_size,
434 type_modifier,
435 format_code,
436 });
437 }
438
439 Ok(BackendMessage::RowDescription(fields))
440}
441
#[cfg(test)]
mod tests {
    // Unwrapping is how a test should fail loudly on an unexpected decode error.
    #![allow(clippy::unwrap_used)]

    use super::*;

    #[test]
    fn test_decode_authentication_ok() {
        // 'R', length 8 (itself + 4-byte auth code), auth code 0.
        let mut data = BytesMut::from(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0][..]);

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Ok) => {}
            _ => panic!("expected Authentication::Ok"),
        }
        // The tag byte is not covered by the length field.
        assert_eq!(consumed, 9);
    }

    #[test]
    fn test_decode_ready_for_query() {
        let mut data = BytesMut::from(&[b'Z', 0, 0, 0, 5, b'I'][..]);

        let (msg, consumed) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::ReadyForQuery { status } => assert_eq!(status, b'I'),
            _ => panic!("expected ReadyForQuery"),
        }
        assert_eq!(consumed, 6);
    }

    #[test]
    fn incomplete_header_is_unexpected_eof() {
        let mut data = BytesMut::from(&[b'Z', 0, 0][..]);
        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn length_below_four_is_rejected() {
        let mut data = BytesMut::from(&[b'Z', 0, 0, 0, 3][..]);
        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn incomplete_body_is_unexpected_eof() {
        // Declares one body byte that has not arrived yet.
        let mut data = BytesMut::from(&[b'Z', 0, 0, 0, 5][..]);
        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn bytes_after_the_frame_are_not_consumed() {
        // The start of a second message follows the first frame.
        let mut data = BytesMut::from(&[b'Z', 0, 0, 0, 5, b'I', b'Z', 0][..]);
        let (_, consumed) = decode_message(&mut data).unwrap();
        assert_eq!(consumed, 6);
    }

    #[test]
    fn md5_salt_is_extracted() {
        // 'R', length 12, auth code 5 (MD5), salt 01 02 03 04.
        let mut data = BytesMut::from(&[b'R', 0, 0, 0, 12, 0, 0, 0, 5, 1, 2, 3, 4][..]);
        let (msg, _) = decode_message(&mut data).unwrap();
        match msg {
            BackendMessage::Authentication(AuthenticationMessage::Md5Password { salt }) => {
                assert_eq!(salt, [1, 2, 3, 4]);
            }
            other => panic!("expected Authentication::Md5Password, got {:?}", other),
        }
    }

    #[test]
    fn data_row_null_and_value_fields() {
        let mut data = BytesMut::from(
            &[
                b'D', 0, 0, 0, 16, // tag, length = 4 + 2 + 4 + 4 + 2
                0, 2, // two columns
                0xFF, 0xFF, 0xFF, 0xFF, // column 1: length -1 (NULL)
                0, 0, 0, 2, b'a', b'b', // column 2: "ab"
            ][..],
        );

        let (msg, consumed) = decode_message(&mut data).unwrap();
        assert_eq!(consumed, 17);
        match msg {
            BackendMessage::DataRow(fields) => {
                assert_eq!(fields.len(), 2);
                assert!(fields[0].is_none(), "length -1 must decode as NULL");
                assert_eq!(fields[1].as_deref(), Some(&b"ab"[..]));
            }
            other => panic!("expected DataRow, got {:?}", other),
        }
    }

    #[test]
    fn data_row_negative_field_length_is_rejected() {
        // One column whose length is -2: neither NULL (-1) nor a valid length.
        let mut data = BytesMut::from(&[b'D', 0, 0, 0, 10, 0, 1, 0xFF, 0xFF, 0xFF, 0xFE][..]);
        let err = decode_message(&mut data).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    /// Builds a `DataRow` frame with `count` columns, all NULL.
    fn make_data_row_with_count(count: i16) -> BytesMut {
        // Body: 2-byte column count plus a 4-byte length (-1) per column.
        let body_len: u32 = 2 + 4 * u32::from(count.unsigned_abs());
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"D");
        // The length field counts its own four bytes.
        buf.extend_from_slice(&(body_len + 4).to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        for _ in 0..count {
            buf.extend_from_slice(&(-1i32).to_be_bytes());
        }
        buf
    }

    /// Builds a `RowDescription` frame with `count` columns. Each column has an empty
    /// name and an all-zero 18-byte descriptor.
    fn make_row_description_with_count(count: i16) -> BytesMut {
        // Body: 2-byte count plus (1-byte empty name terminator + 18 bytes) per column.
        let body_len: u32 = 2 + 19 * u32::from(count.unsigned_abs());
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"T");
        buf.extend_from_slice(&(body_len + 4).to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        for _ in 0..count {
            buf.extend_from_slice(&[0u8]);
            buf.extend_from_slice(&[0u8; 18]);
        }
        buf
    }

    #[test]
    fn test_data_row_zero_fields_accepted() {
        let mut buf = make_data_row_with_count(0);
        let result = decode_message(&mut buf);
        assert!(result.is_ok(), "zero-field DataRow must be accepted");
        let (msg, _) = result.unwrap();
        assert!(matches!(msg, BackendMessage::DataRow(fields) if fields.is_empty()));
    }

    #[test]
    fn test_data_row_field_count_exceeds_max_is_rejected() {
        let count = (MAX_FIELD_COUNT + 1) as i16;
        // Only a token body follows; the count must be rejected before any column
        // is read.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"D");
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        buf.extend_from_slice(&[0u8; 4]);

        let result = decode_message(&mut buf);
        assert!(result.is_err(), "DataRow with {} fields must be rejected", count);
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains(&MAX_FIELD_COUNT.to_string()),
            "error must mention the limit: {}",
            msg
        );
    }

    #[test]
    fn test_row_description_field_count_exceeds_max_is_rejected() {
        let count = (MAX_FIELD_COUNT + 1) as i16;
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"T");
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&count.to_be_bytes());
        buf.extend_from_slice(&[0u8; 4]);

        let result = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "RowDescription with {} fields must be rejected",
            count
        );
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains(&MAX_FIELD_COUNT.to_string()),
            "error must mention the limit: {}",
            msg
        );
    }

    #[test]
    fn test_row_description_small_field_count_accepted() {
        let mut buf = make_row_description_with_count(3);
        let result = decode_message(&mut buf);
        assert!(
            result.is_ok(),
            "3-field RowDescription must be accepted: {:?}",
            result
        );
        let (msg, _) = result.unwrap();
        assert!(matches!(msg, BackendMessage::RowDescription(fields) if fields.len() == 3));
    }

    /// Builds an `ErrorResponse` frame holding a single field plus the list terminator.
    fn make_error_response(field_type: u8, field_value: &[u8]) -> BytesMut {
        // Body: code byte + value + value NUL + list-terminating zero byte.
        let body_len = 1 + field_value.len() + 1 + 1;
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"E");
        buf.extend_from_slice(&(body_len as u32 + 4).to_be_bytes());
        buf.extend_from_slice(&[field_type]);
        buf.extend_from_slice(field_value);
        buf.extend_from_slice(&[0]);
        buf.extend_from_slice(&[0]);
        buf
    }

    #[test]
    fn error_field_within_limit_is_accepted() {
        let value = vec![b'x'; 1024];
        let mut buf = make_error_response(b'M', &value);
        let result = decode_message(&mut buf);
        assert!(
            result.is_ok(),
            "small error field must be accepted: {:?}",
            result
        );
        let (msg, _) = result.unwrap();
        match msg {
            BackendMessage::ErrorResponse(fields) => {
                let expected = "x".repeat(1024);
                assert_eq!(fields.message.as_deref(), Some(expected.as_str()));
            }
            other => panic!("expected ErrorResponse, got {:?}", other),
        }
    }

    #[test]
    fn error_field_exceeding_limit_is_rejected() {
        let value = vec![b'x'; MAX_ERROR_FIELD_BYTES + 1];
        let mut buf = make_error_response(b'M', &value);
        let result = decode_message(&mut buf);
        assert!(result.is_err(), "oversized error field must be rejected");
        let err = result.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        let msg = err.to_string();
        assert!(
            msg.contains("too large") || msg.contains(&MAX_ERROR_FIELD_BYTES.to_string()),
            "error must mention size limit: {}",
            msg
        );
    }

    /// Builds an `AuthenticationSASL` frame (auth code 10) advertising `mechanisms`.
    fn make_sasl_auth(mechanisms: &[&str]) -> BytesMut {
        let mut mechanism_bytes: Vec<u8> = Vec::new();
        for m in mechanisms {
            mechanism_bytes.extend_from_slice(m.as_bytes());
            mechanism_bytes.push(0);
        }
        // An empty name terminates the mechanism list.
        mechanism_bytes.push(0);
        // Body: 4-byte auth code + mechanism list.
        let body_len = 4 + mechanism_bytes.len();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"R");
        buf.extend_from_slice(&(body_len as u32 + 4).to_be_bytes());
        buf.extend_from_slice(&10u32.to_be_bytes());
        buf.extend_from_slice(&mechanism_bytes);
        buf
    }

    /// Decodes `buf` and returns the SASL mechanism list, failing the test on any
    /// other outcome.
    fn decode_sasl_mechanisms(buf: &mut BytesMut) -> Vec<String> {
        let result = decode_message(buf);
        assert!(result.is_ok(), "SASL message must parse: {:?}", result);
        match result.unwrap() {
            (BackendMessage::Authentication(AuthenticationMessage::Sasl { mechanisms }), _) => {
                mechanisms
            }
            (other, _) => panic!("expected Authentication::Sasl, got {:?}", other),
        }
    }

    #[test]
    fn sasl_mechanisms_within_limit_are_accepted() {
        let mechanisms: Vec<&str> = (0..MAX_SASL_MECHANISMS).map(|_| "SCRAM-SHA-256").collect();
        let mut buf = make_sasl_auth(&mechanisms);
        let parsed = decode_sasl_mechanisms(&mut buf);
        assert_eq!(
            parsed.len(),
            MAX_SASL_MECHANISMS,
            "all {} mechanisms must be kept",
            MAX_SASL_MECHANISMS
        );
        assert!(parsed.iter().all(|m| m == "SCRAM-SHA-256"));
    }

    #[test]
    fn sasl_mechanisms_exceeding_limit_are_truncated_not_rejected() {
        let mechanisms: Vec<&str> = (0..MAX_SASL_MECHANISMS + 5)
            .map(|_| "SCRAM-SHA-256")
            .collect();
        let mut buf = make_sasl_auth(&mechanisms);
        // Decoding must succeed *and* yield the Sasl variant truncated to exactly
        // the cap. A different variant must fail the test, not pass silently.
        let parsed = decode_sasl_mechanisms(&mut buf);
        assert_eq!(
            parsed.len(),
            MAX_SASL_MECHANISMS,
            "parsed mechanisms must be truncated to the cap"
        );
    }

    /// Builds a `ParameterStatus` frame for `name` = `value`.
    fn make_parameter_status(name: &[u8], value: &[u8]) -> BytesMut {
        let body_len = name.len() + 1 + value.len() + 1;
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"S");
        buf.extend_from_slice(&(body_len as u32 + 4).to_be_bytes());
        buf.extend_from_slice(name);
        buf.extend_from_slice(&[0]);
        buf.extend_from_slice(value);
        buf.extend_from_slice(&[0]);
        buf
    }

    #[test]
    fn parameter_status_normal_is_accepted() {
        let mut buf = make_parameter_status(b"server_version", b"16.0");
        let result = decode_message(&mut buf);
        assert!(
            result.is_ok(),
            "normal ParameterStatus must be accepted: {:?}",
            result
        );
        match result.unwrap() {
            (BackendMessage::ParameterStatus { name, value }, _) => {
                assert_eq!(name, "server_version");
                assert_eq!(value, "16.0");
            }
            (other, _) => panic!("expected ParameterStatus, got {:?}", other),
        }
    }

    #[test]
    fn parameter_name_exceeding_limit_is_rejected() {
        let long_name = vec![b'a'; MAX_PARAMETER_NAME_BYTES + 1];
        let mut buf = make_parameter_status(&long_name, b"value");
        let result = decode_message(&mut buf);
        assert!(result.is_err(), "oversized parameter name must be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("too long") || msg.contains(&MAX_PARAMETER_NAME_BYTES.to_string()),
            "error must mention the name limit: {}",
            msg
        );
    }

    #[test]
    fn parameter_value_exceeding_limit_is_rejected() {
        let long_value = vec![b'v'; MAX_PARAMETER_VALUE_BYTES + 1];
        let mut buf = make_parameter_status(b"timezone", &long_value);
        let result = decode_message(&mut buf);
        assert!(
            result.is_err(),
            "oversized parameter value must be rejected"
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("too long") || msg.contains(&MAX_PARAMETER_VALUE_BYTES.to_string()),
            "error must mention the value limit: {}",
            msg
        );
    }
}
725}