1use super::types::*;
4
/// Largest value accepted in a backend frame's 4-byte length field (64 MiB).
///
/// `BackendMessage::decode` rejects any frame whose declared length exceeds this.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
/// Copies `bytes` into an owned `String`.
///
/// # Errors
/// If `bytes` is not valid UTF-8, returns a message of the form
/// `"<context> is not valid UTF-8: <utf8 error>"`, so the caller can name the
/// protocol field that failed.
fn decode_utf8(bytes: &[u8], context: &str) -> Result<String, String> {
    match std::str::from_utf8(bytes) {
        Ok(text) => Ok(text.to_owned()),
        Err(err) => Err(format!("{} is not valid UTF-8: {}", context, err)),
    }
}
16
17impl BackendMessage {
18 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
20 if buf.len() < 5 {
21 return Err("Buffer too short".to_string());
22 }
23
24 let msg_type = buf[0];
25 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
26
27 if len < 4 {
30 return Err(format!("Invalid message length: {} (minimum is 4)", len));
31 }
32 if len > MAX_BACKEND_FRAME_LEN {
33 return Err(format!(
34 "Message too large: {} bytes (max {})",
35 len, MAX_BACKEND_FRAME_LEN
36 ));
37 }
38
39 let frame_len = len
40 .checked_add(1)
41 .ok_or_else(|| "Message length overflow".to_string())?;
42
43 if buf.len() < frame_len {
44 return Err("Incomplete message".to_string());
45 }
46
47 let payload = &buf[5..frame_len];
48
49 let message = match msg_type {
50 b'R' => Self::decode_auth(payload)?,
51 b'S' => Self::decode_parameter_status(payload)?,
52 b'K' => Self::decode_backend_key(payload)?,
53 b'v' => Self::decode_negotiate_protocol_version(payload)?,
54 b'Z' => Self::decode_ready_for_query(payload)?,
55 b'T' => Self::decode_row_description(payload)?,
56 b'D' => Self::decode_data_row(payload)?,
57 b'C' => Self::decode_command_complete(payload)?,
58 b'E' => Self::decode_error_response(payload)?,
59 b'1' => {
60 if !payload.is_empty() {
61 return Err("ParseComplete must have empty payload".to_string());
62 }
63 BackendMessage::ParseComplete
64 }
65 b'2' => {
66 if !payload.is_empty() {
67 return Err("BindComplete must have empty payload".to_string());
68 }
69 BackendMessage::BindComplete
70 }
71 b'3' => {
72 if !payload.is_empty() {
73 return Err("CloseComplete must have empty payload".to_string());
74 }
75 BackendMessage::CloseComplete
76 }
77 b'n' => {
78 if !payload.is_empty() {
79 return Err("NoData must have empty payload".to_string());
80 }
81 BackendMessage::NoData
82 }
83 b's' => {
84 if !payload.is_empty() {
85 return Err("PortalSuspended must have empty payload".to_string());
86 }
87 BackendMessage::PortalSuspended
88 }
89 b't' => Self::decode_parameter_description(payload)?,
90 b'G' => Self::decode_copy_in_response(payload)?,
91 b'H' => Self::decode_copy_out_response(payload)?,
92 b'W' => Self::decode_copy_both_response(payload)?,
93 b'd' => BackendMessage::CopyData(payload.to_vec()),
94 b'c' => {
95 if !payload.is_empty() {
96 return Err("CopyDone must have empty payload".to_string());
97 }
98 BackendMessage::CopyDone
99 }
100 b'A' => Self::decode_notification_response(payload)?,
101 b'I' => {
102 if !payload.is_empty() {
103 return Err("EmptyQueryResponse must have empty payload".to_string());
104 }
105 BackendMessage::EmptyQueryResponse
106 }
107 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
108 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
109 };
110
111 Ok((message, frame_len))
112 }
113
114 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
115 if payload.len() < 4 {
116 return Err("Auth payload too short".to_string());
117 }
118 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
119 match auth_type {
120 0 => {
121 if payload.len() != 4 {
122 return Err(format!(
123 "AuthenticationOk invalid payload length: {}",
124 payload.len()
125 ));
126 }
127 Ok(BackendMessage::AuthenticationOk)
128 }
129 2 => {
130 if payload.len() != 4 {
131 return Err(format!(
132 "AuthenticationKerberosV5 invalid payload length: {}",
133 payload.len()
134 ));
135 }
136 Ok(BackendMessage::AuthenticationKerberosV5)
137 }
138 3 => {
139 if payload.len() != 4 {
140 return Err(format!(
141 "AuthenticationCleartextPassword invalid payload length: {}",
142 payload.len()
143 ));
144 }
145 Ok(BackendMessage::AuthenticationCleartextPassword)
146 }
147 5 => {
148 if payload.len() != 8 {
149 return Err("MD5 auth payload too short (need salt)".to_string());
150 }
151 let mut salt = [0u8; 4];
152 salt.copy_from_slice(&payload[4..8]);
153 Ok(BackendMessage::AuthenticationMD5Password(salt))
154 }
155 6 => {
156 if payload.len() != 4 {
157 return Err(format!(
158 "AuthenticationSCMCredential invalid payload length: {}",
159 payload.len()
160 ));
161 }
162 Ok(BackendMessage::AuthenticationSCMCredential)
163 }
164 7 => {
165 if payload.len() != 4 {
166 return Err(format!(
167 "AuthenticationGSS invalid payload length: {}",
168 payload.len()
169 ));
170 }
171 Ok(BackendMessage::AuthenticationGSS)
172 }
173 8 => Ok(BackendMessage::AuthenticationGSSContinue(
174 payload[4..].to_vec(),
175 )),
176 9 => {
177 if payload.len() != 4 {
178 return Err(format!(
179 "AuthenticationSSPI invalid payload length: {}",
180 payload.len()
181 ));
182 }
183 Ok(BackendMessage::AuthenticationSSPI)
184 }
185 10 => {
186 let mut mechanisms = Vec::new();
188 let mut pos = 4;
189 while pos < payload.len() {
190 if payload[pos] == 0 {
191 break; }
193 let end = payload[pos..]
194 .iter()
195 .position(|&b| b == 0)
196 .map(|p| pos + p)
197 .ok_or("SASL mechanism list missing null terminator")?;
198 mechanisms.push(decode_utf8(&payload[pos..end], "SASL mechanism")?);
199 pos = end + 1;
200 }
201 if pos >= payload.len() {
202 return Err("SASL mechanism list missing final terminator".to_string());
203 }
204 if pos + 1 != payload.len() {
205 return Err("SASL mechanism list has trailing bytes".to_string());
206 }
207 if mechanisms.is_empty() {
208 return Err("SASL mechanism list is empty".to_string());
209 }
210 Ok(BackendMessage::AuthenticationSASL(mechanisms))
211 }
212 11 => {
213 Ok(BackendMessage::AuthenticationSASLContinue(
215 payload[4..].to_vec(),
216 ))
217 }
218 12 => {
219 Ok(BackendMessage::AuthenticationSASLFinal(
221 payload[4..].to_vec(),
222 ))
223 }
224 _ => Err(format!("Unknown auth type: {}", auth_type)),
225 }
226 }
227
228 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
229 let name_end = payload
230 .iter()
231 .position(|&b| b == 0)
232 .ok_or("ParameterStatus missing name terminator")?;
233 let value_start = name_end + 1;
234 if value_start > payload.len() {
235 return Err("ParameterStatus missing value".to_string());
236 }
237 let value_end_rel = payload[value_start..]
238 .iter()
239 .position(|&b| b == 0)
240 .ok_or("ParameterStatus missing value terminator")?;
241 let value_end = value_start + value_end_rel;
242 if value_end + 1 != payload.len() {
243 return Err("ParameterStatus has trailing bytes".to_string());
244 }
245 Ok(BackendMessage::ParameterStatus {
246 name: decode_utf8(&payload[..name_end], "ParameterStatus name")?,
247 value: decode_utf8(&payload[value_start..value_end], "ParameterStatus value")?,
248 })
249 }
250
251 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
252 if payload.len() < 8 {
253 return Err("BackendKeyData payload too short".to_string());
254 }
255 let key_len = payload.len() - 4;
256 if !(4..=256).contains(&key_len) {
257 return Err(format!(
258 "BackendKeyData invalid secret key length: {} (expected 4..=256)",
259 key_len
260 ));
261 }
262 Ok(BackendMessage::BackendKeyData {
263 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
264 secret_key: payload[4..].to_vec(),
265 })
266 }
267
268 fn decode_negotiate_protocol_version(payload: &[u8]) -> Result<Self, String> {
269 if payload.len() < 8 {
270 return Err("NegotiateProtocolVersion payload too short".to_string());
271 }
272
273 let newest_minor_supported =
274 i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
275 if newest_minor_supported < 0 {
276 return Err("NegotiateProtocolVersion newest_minor_supported is negative".to_string());
277 }
278
279 let unrecognized_count =
280 i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
281 if unrecognized_count < 0 {
282 return Err(
283 "NegotiateProtocolVersion unrecognized option count is negative".to_string(),
284 );
285 }
286 let unrecognized_count = unrecognized_count as usize;
287 let remaining = payload.len() - 8;
288 if unrecognized_count > remaining {
291 return Err(format!(
292 "NegotiateProtocolVersion unrecognized option count {} exceeds payload capacity {}",
293 unrecognized_count, remaining
294 ));
295 }
296
297 let mut options = Vec::with_capacity(unrecognized_count);
298 let mut pos = 8usize;
299 for _ in 0..unrecognized_count {
300 if pos >= payload.len() {
301 return Err("NegotiateProtocolVersion missing option string terminator".to_string());
302 }
303 let rel_end = payload[pos..]
304 .iter()
305 .position(|&b| b == 0)
306 .ok_or("NegotiateProtocolVersion option missing null terminator")?;
307 let end = pos + rel_end;
308 options.push(decode_utf8(
309 &payload[pos..end],
310 "NegotiateProtocolVersion option",
311 )?);
312 pos = end + 1;
313 }
314
315 if pos != payload.len() {
316 return Err("NegotiateProtocolVersion has trailing bytes".to_string());
317 }
318
319 Ok(BackendMessage::NegotiateProtocolVersion {
320 newest_minor_supported,
321 unrecognized_protocol_options: options,
322 })
323 }
324
325 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
326 if payload.len() != 1 {
327 return Err("ReadyForQuery payload empty".to_string());
328 }
329 let status = match payload[0] {
330 b'I' => TransactionStatus::Idle,
331 b'T' => TransactionStatus::InBlock,
332 b'E' => TransactionStatus::Failed,
333 _ => return Err("Unknown transaction status".to_string()),
334 };
335 Ok(BackendMessage::ReadyForQuery(status))
336 }
337
338 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
339 if payload.len() < 2 {
340 return Err("RowDescription payload too short".to_string());
341 }
342
343 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
344 if raw_count < 0 {
345 return Err(format!("RowDescription invalid field count: {}", raw_count));
346 }
347 let field_count = raw_count as usize;
348 let mut fields = Vec::with_capacity(field_count);
349 let mut pos = 2;
350
351 for _ in 0..field_count {
352 let name_end = payload[pos..]
354 .iter()
355 .position(|&b| b == 0)
356 .ok_or("Missing null terminator in field name")?;
357 let name = decode_utf8(&payload[pos..pos + name_end], "RowDescription field name")?;
358 pos += name_end + 1; if pos + 18 > payload.len() {
362 return Err("RowDescription field truncated".to_string());
363 }
364
365 let table_oid = u32::from_be_bytes([
366 payload[pos],
367 payload[pos + 1],
368 payload[pos + 2],
369 payload[pos + 3],
370 ]);
371 pos += 4;
372
373 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
374 pos += 2;
375
376 let type_oid = u32::from_be_bytes([
377 payload[pos],
378 payload[pos + 1],
379 payload[pos + 2],
380 payload[pos + 3],
381 ]);
382 pos += 4;
383
384 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
385 pos += 2;
386
387 let type_modifier = i32::from_be_bytes([
388 payload[pos],
389 payload[pos + 1],
390 payload[pos + 2],
391 payload[pos + 3],
392 ]);
393 pos += 4;
394
395 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
396 if !(0..=1).contains(&format) {
397 return Err(format!("RowDescription invalid format code: {}", format));
398 }
399 pos += 2;
400
401 fields.push(FieldDescription {
402 name,
403 table_oid,
404 column_attr,
405 type_oid,
406 type_size,
407 type_modifier,
408 format,
409 });
410 }
411
412 if pos != payload.len() {
413 return Err("RowDescription has trailing bytes".to_string());
414 }
415
416 Ok(BackendMessage::RowDescription(fields))
417 }
418
419 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
420 if payload.len() < 2 {
421 return Err("DataRow payload too short".to_string());
422 }
423
424 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
425 if raw_count < 0 {
426 return Err(format!("DataRow invalid column count: {}", raw_count));
427 }
428 let column_count = raw_count as usize;
429 if column_count > (payload.len() - 2) / 4 + 1 {
431 return Err(format!(
432 "DataRow claims {} columns but payload is only {} bytes",
433 column_count,
434 payload.len()
435 ));
436 }
437 let mut columns = Vec::with_capacity(column_count);
438 let mut pos = 2;
439
440 for _ in 0..column_count {
441 if pos + 4 > payload.len() {
442 return Err("DataRow truncated".to_string());
443 }
444
445 let len = i32::from_be_bytes([
446 payload[pos],
447 payload[pos + 1],
448 payload[pos + 2],
449 payload[pos + 3],
450 ]);
451 pos += 4;
452
453 if len == -1 {
454 columns.push(None);
456 } else {
457 if len < -1 {
458 return Err(format!("DataRow invalid column length: {}", len));
459 }
460 let len = len as usize;
461 if len > payload.len().saturating_sub(pos) {
462 return Err("DataRow column data truncated".to_string());
463 }
464 let data = payload[pos..pos + len].to_vec();
465 pos += len;
466 columns.push(Some(data));
467 }
468 }
469
470 if pos != payload.len() {
471 return Err("DataRow has trailing bytes".to_string());
472 }
473
474 Ok(BackendMessage::DataRow(columns))
475 }
476
477 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
478 if payload.last().copied() != Some(0) {
479 return Err("CommandComplete missing null terminator".to_string());
480 }
481 let tag_bytes = &payload[..payload.len() - 1];
482 if tag_bytes.contains(&0) {
483 return Err("CommandComplete contains interior null byte".to_string());
484 }
485 let tag = decode_utf8(tag_bytes, "CommandComplete tag")?;
486 Ok(BackendMessage::CommandComplete(tag))
487 }
488
489 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
490 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
491 payload,
492 )?))
493 }
494
495 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
496 if payload.last().copied() != Some(0) {
497 return Err("ErrorResponse missing final terminator".to_string());
498 }
499 let mut fields = ErrorFields::default();
500 let mut i = 0;
501 while i < payload.len() && payload[i] != 0 {
502 let field_type = payload[i];
503 i += 1;
504 let end = payload[i..]
505 .iter()
506 .position(|&b| b == 0)
507 .map(|p| p + i)
508 .ok_or("ErrorResponse field missing null terminator")?;
509 let value = decode_utf8(&payload[i..end], "ErrorResponse field")?;
510 i = end + 1;
511
512 match field_type {
513 b'S' => fields.severity = value,
514 b'C' => fields.code = value,
515 b'M' => fields.message = value,
516 b'D' => fields.detail = Some(value),
517 b'H' => fields.hint = Some(value),
518 _ => {}
519 }
520 }
521 if i + 1 != payload.len() {
522 return Err("ErrorResponse has trailing bytes after terminator".to_string());
523 }
524 Ok(fields)
525 }
526
527 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
528 if payload.len() < 2 {
529 return Err("ParameterDescription payload too short".to_string());
530 }
531 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
532 if raw_count < 0 {
533 return Err(format!("ParameterDescription invalid count: {}", raw_count));
534 }
535 let count = raw_count as usize;
536 let expected_len = 2 + count * 4;
537 if payload.len() < expected_len {
538 return Err(format!(
539 "ParameterDescription truncated: expected {} bytes, got {}",
540 expected_len,
541 payload.len()
542 ));
543 }
544 let mut oids = Vec::with_capacity(count);
545 let mut pos = 2;
546 for _ in 0..count {
547 oids.push(u32::from_be_bytes([
548 payload[pos],
549 payload[pos + 1],
550 payload[pos + 2],
551 payload[pos + 3],
552 ]));
553 pos += 4;
554 }
555 if pos != payload.len() {
556 return Err("ParameterDescription has trailing bytes".to_string());
557 }
558 Ok(BackendMessage::ParameterDescription(oids))
559 }
560
561 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
562 if payload.len() < 3 {
563 return Err("CopyInResponse payload too short".to_string());
564 }
565 let format = payload[0];
566 if format > 1 {
567 return Err(format!(
568 "CopyInResponse invalid overall format code: {}",
569 format
570 ));
571 }
572 let num_columns = if payload.len() >= 3 {
573 let raw = i16::from_be_bytes([payload[1], payload[2]]);
574 if raw < 0 {
575 return Err(format!(
576 "CopyInResponse invalid negative column count: {}",
577 raw
578 ));
579 }
580 raw as usize
581 } else {
582 0
583 };
584 let mut column_formats = Vec::with_capacity(num_columns);
585 let mut pos = 3usize;
586 for _ in 0..num_columns {
587 if pos + 2 > payload.len() {
588 return Err("CopyInResponse truncated column format list".to_string());
589 }
590 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
591 if !(0..=1).contains(&raw) {
592 return Err(format!("CopyInResponse invalid format code: {}", raw));
593 }
594 column_formats.push(raw as u8);
595 pos += 2;
596 }
597 if pos != payload.len() {
598 return Err("CopyInResponse has trailing bytes".to_string());
599 }
600 Ok(BackendMessage::CopyInResponse {
601 format,
602 column_formats,
603 })
604 }
605
606 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
607 if payload.len() < 3 {
608 return Err("CopyOutResponse payload too short".to_string());
609 }
610 let format = payload[0];
611 if format > 1 {
612 return Err(format!(
613 "CopyOutResponse invalid overall format code: {}",
614 format
615 ));
616 }
617 let num_columns = if payload.len() >= 3 {
618 let raw = i16::from_be_bytes([payload[1], payload[2]]);
619 if raw < 0 {
620 return Err(format!(
621 "CopyOutResponse invalid negative column count: {}",
622 raw
623 ));
624 }
625 raw as usize
626 } else {
627 0
628 };
629 let mut column_formats = Vec::with_capacity(num_columns);
630 let mut pos = 3usize;
631 for _ in 0..num_columns {
632 if pos + 2 > payload.len() {
633 return Err("CopyOutResponse truncated column format list".to_string());
634 }
635 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
636 if !(0..=1).contains(&raw) {
637 return Err(format!("CopyOutResponse invalid format code: {}", raw));
638 }
639 column_formats.push(raw as u8);
640 pos += 2;
641 }
642 if pos != payload.len() {
643 return Err("CopyOutResponse has trailing bytes".to_string());
644 }
645 Ok(BackendMessage::CopyOutResponse {
646 format,
647 column_formats,
648 })
649 }
650
651 fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
652 if payload.len() < 3 {
653 return Err("CopyBothResponse payload too short".to_string());
654 }
655 let format = payload[0];
656 if format > 1 {
657 return Err(format!(
658 "CopyBothResponse invalid overall format code: {}",
659 format
660 ));
661 }
662 let num_columns = if payload.len() >= 3 {
663 let raw = i16::from_be_bytes([payload[1], payload[2]]);
664 if raw < 0 {
665 return Err(format!(
666 "CopyBothResponse invalid negative column count: {}",
667 raw
668 ));
669 }
670 raw as usize
671 } else {
672 0
673 };
674 let mut column_formats = Vec::with_capacity(num_columns);
675 let mut pos = 3usize;
676 for _ in 0..num_columns {
677 if pos + 2 > payload.len() {
678 return Err("CopyBothResponse truncated column format list".to_string());
679 }
680 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
681 if !(0..=1).contains(&raw) {
682 return Err(format!("CopyBothResponse invalid format code: {}", raw));
683 }
684 column_formats.push(raw as u8);
685 pos += 2;
686 }
687 if pos != payload.len() {
688 return Err("CopyBothResponse has trailing bytes".to_string());
689 }
690 Ok(BackendMessage::CopyBothResponse {
691 format,
692 column_formats,
693 })
694 }
695
696 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
697 if payload.len() < 6 {
698 return Err("NotificationResponse too short".to_string());
700 }
701 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
702
703 let mut i = 4;
705 let remaining = payload.get(i..).unwrap_or(&[]);
706 let channel_end = remaining
707 .iter()
708 .position(|&b| b == 0)
709 .ok_or("NotificationResponse: missing channel null terminator")?;
710 let channel = decode_utf8(&remaining[..channel_end], "NotificationResponse channel")?;
711 i += channel_end + 1;
712
713 let remaining = payload.get(i..).unwrap_or(&[]);
715 let payload_end = remaining
716 .iter()
717 .position(|&b| b == 0)
718 .ok_or("NotificationResponse: missing payload null terminator")?;
719 let notification_payload =
720 decode_utf8(&remaining[..payload_end], "NotificationResponse payload")?;
721 if i + payload_end + 1 != payload.len() {
722 return Err("NotificationResponse has trailing bytes".to_string());
723 }
724
725 Ok(BackendMessage::NotificationResponse {
726 process_id,
727 channel,
728 payload: notification_payload,
729 })
730 }
731}