1use super::types::*;
4
/// Largest value accepted in a backend frame's 4-byte length word: 64 MiB (2^26 bytes).
///
/// The decoder rejects any frame that declares a larger length before it slices out the payload.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 1 << 26;
10
11impl BackendMessage {
12 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
14 if buf.len() < 5 {
15 return Err("Buffer too short".to_string());
16 }
17
18 let msg_type = buf[0];
19 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
20
21 if len < 4 {
24 return Err(format!("Invalid message length: {} (minimum is 4)", len));
25 }
26 if len > MAX_BACKEND_FRAME_LEN {
27 return Err(format!(
28 "Message too large: {} bytes (max {})",
29 len, MAX_BACKEND_FRAME_LEN
30 ));
31 }
32
33 let frame_len = len
34 .checked_add(1)
35 .ok_or_else(|| "Message length overflow".to_string())?;
36
37 if buf.len() < frame_len {
38 return Err("Incomplete message".to_string());
39 }
40
41 let payload = &buf[5..frame_len];
42
43 let message = match msg_type {
44 b'R' => Self::decode_auth(payload)?,
45 b'S' => Self::decode_parameter_status(payload)?,
46 b'K' => Self::decode_backend_key(payload)?,
47 b'v' => Self::decode_negotiate_protocol_version(payload)?,
48 b'Z' => Self::decode_ready_for_query(payload)?,
49 b'T' => Self::decode_row_description(payload)?,
50 b'D' => Self::decode_data_row(payload)?,
51 b'C' => Self::decode_command_complete(payload)?,
52 b'E' => Self::decode_error_response(payload)?,
53 b'1' => {
54 if !payload.is_empty() {
55 return Err("ParseComplete must have empty payload".to_string());
56 }
57 BackendMessage::ParseComplete
58 }
59 b'2' => {
60 if !payload.is_empty() {
61 return Err("BindComplete must have empty payload".to_string());
62 }
63 BackendMessage::BindComplete
64 }
65 b'3' => {
66 if !payload.is_empty() {
67 return Err("CloseComplete must have empty payload".to_string());
68 }
69 BackendMessage::CloseComplete
70 }
71 b'n' => {
72 if !payload.is_empty() {
73 return Err("NoData must have empty payload".to_string());
74 }
75 BackendMessage::NoData
76 }
77 b's' => {
78 if !payload.is_empty() {
79 return Err("PortalSuspended must have empty payload".to_string());
80 }
81 BackendMessage::PortalSuspended
82 }
83 b't' => Self::decode_parameter_description(payload)?,
84 b'G' => Self::decode_copy_in_response(payload)?,
85 b'H' => Self::decode_copy_out_response(payload)?,
86 b'W' => Self::decode_copy_both_response(payload)?,
87 b'd' => BackendMessage::CopyData(payload.to_vec()),
88 b'c' => {
89 if !payload.is_empty() {
90 return Err("CopyDone must have empty payload".to_string());
91 }
92 BackendMessage::CopyDone
93 }
94 b'A' => Self::decode_notification_response(payload)?,
95 b'I' => {
96 if !payload.is_empty() {
97 return Err("EmptyQueryResponse must have empty payload".to_string());
98 }
99 BackendMessage::EmptyQueryResponse
100 }
101 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
102 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
103 };
104
105 Ok((message, frame_len))
106 }
107
108 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
109 if payload.len() < 4 {
110 return Err("Auth payload too short".to_string());
111 }
112 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
113 match auth_type {
114 0 => {
115 if payload.len() != 4 {
116 return Err(format!(
117 "AuthenticationOk invalid payload length: {}",
118 payload.len()
119 ));
120 }
121 Ok(BackendMessage::AuthenticationOk)
122 }
123 2 => {
124 if payload.len() != 4 {
125 return Err(format!(
126 "AuthenticationKerberosV5 invalid payload length: {}",
127 payload.len()
128 ));
129 }
130 Ok(BackendMessage::AuthenticationKerberosV5)
131 }
132 3 => {
133 if payload.len() != 4 {
134 return Err(format!(
135 "AuthenticationCleartextPassword invalid payload length: {}",
136 payload.len()
137 ));
138 }
139 Ok(BackendMessage::AuthenticationCleartextPassword)
140 }
141 5 => {
142 if payload.len() != 8 {
143 return Err("MD5 auth payload too short (need salt)".to_string());
144 }
145 let mut salt = [0u8; 4];
146 salt.copy_from_slice(&payload[4..8]);
147 Ok(BackendMessage::AuthenticationMD5Password(salt))
148 }
149 6 => {
150 if payload.len() != 4 {
151 return Err(format!(
152 "AuthenticationSCMCredential invalid payload length: {}",
153 payload.len()
154 ));
155 }
156 Ok(BackendMessage::AuthenticationSCMCredential)
157 }
158 7 => {
159 if payload.len() != 4 {
160 return Err(format!(
161 "AuthenticationGSS invalid payload length: {}",
162 payload.len()
163 ));
164 }
165 Ok(BackendMessage::AuthenticationGSS)
166 }
167 8 => Ok(BackendMessage::AuthenticationGSSContinue(
168 payload[4..].to_vec(),
169 )),
170 9 => {
171 if payload.len() != 4 {
172 return Err(format!(
173 "AuthenticationSSPI invalid payload length: {}",
174 payload.len()
175 ));
176 }
177 Ok(BackendMessage::AuthenticationSSPI)
178 }
179 10 => {
180 let mut mechanisms = Vec::new();
182 let mut pos = 4;
183 while pos < payload.len() {
184 if payload[pos] == 0 {
185 break; }
187 let end = payload[pos..]
188 .iter()
189 .position(|&b| b == 0)
190 .map(|p| pos + p)
191 .ok_or("SASL mechanism list missing null terminator")?;
192 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
193 pos = end + 1;
194 }
195 if pos >= payload.len() {
196 return Err("SASL mechanism list missing final terminator".to_string());
197 }
198 if pos + 1 != payload.len() {
199 return Err("SASL mechanism list has trailing bytes".to_string());
200 }
201 if mechanisms.is_empty() {
202 return Err("SASL mechanism list is empty".to_string());
203 }
204 Ok(BackendMessage::AuthenticationSASL(mechanisms))
205 }
206 11 => {
207 Ok(BackendMessage::AuthenticationSASLContinue(
209 payload[4..].to_vec(),
210 ))
211 }
212 12 => {
213 Ok(BackendMessage::AuthenticationSASLFinal(
215 payload[4..].to_vec(),
216 ))
217 }
218 _ => Err(format!("Unknown auth type: {}", auth_type)),
219 }
220 }
221
222 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
223 let name_end = payload
224 .iter()
225 .position(|&b| b == 0)
226 .ok_or("ParameterStatus missing name terminator")?;
227 let value_start = name_end + 1;
228 if value_start > payload.len() {
229 return Err("ParameterStatus missing value".to_string());
230 }
231 let value_end_rel = payload[value_start..]
232 .iter()
233 .position(|&b| b == 0)
234 .ok_or("ParameterStatus missing value terminator")?;
235 let value_end = value_start + value_end_rel;
236 if value_end + 1 != payload.len() {
237 return Err("ParameterStatus has trailing bytes".to_string());
238 }
239 Ok(BackendMessage::ParameterStatus {
240 name: String::from_utf8_lossy(&payload[..name_end]).to_string(),
241 value: String::from_utf8_lossy(&payload[value_start..value_end]).to_string(),
242 })
243 }
244
245 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
246 if payload.len() < 8 {
247 return Err("BackendKeyData payload too short".to_string());
248 }
249 let key_len = payload.len() - 4;
250 if !(4..=256).contains(&key_len) {
251 return Err(format!(
252 "BackendKeyData invalid secret key length: {} (expected 4..=256)",
253 key_len
254 ));
255 }
256 Ok(BackendMessage::BackendKeyData {
257 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
258 secret_key: payload[4..].to_vec(),
259 })
260 }
261
262 fn decode_negotiate_protocol_version(payload: &[u8]) -> Result<Self, String> {
263 if payload.len() < 8 {
264 return Err("NegotiateProtocolVersion payload too short".to_string());
265 }
266
267 let newest_minor_supported =
268 i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
269 if newest_minor_supported < 0 {
270 return Err("NegotiateProtocolVersion newest_minor_supported is negative".to_string());
271 }
272
273 let unrecognized_count =
274 i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
275 if unrecognized_count < 0 {
276 return Err(
277 "NegotiateProtocolVersion unrecognized option count is negative".to_string(),
278 );
279 }
280 let unrecognized_count = unrecognized_count as usize;
281
282 let mut options = Vec::with_capacity(unrecognized_count);
283 let mut pos = 8usize;
284 for _ in 0..unrecognized_count {
285 if pos >= payload.len() {
286 return Err("NegotiateProtocolVersion missing option string terminator".to_string());
287 }
288 let rel_end = payload[pos..]
289 .iter()
290 .position(|&b| b == 0)
291 .ok_or("NegotiateProtocolVersion option missing null terminator")?;
292 let end = pos + rel_end;
293 options.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
294 pos = end + 1;
295 }
296
297 if pos != payload.len() {
298 return Err("NegotiateProtocolVersion has trailing bytes".to_string());
299 }
300
301 Ok(BackendMessage::NegotiateProtocolVersion {
302 newest_minor_supported,
303 unrecognized_protocol_options: options,
304 })
305 }
306
307 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
308 if payload.len() != 1 {
309 return Err("ReadyForQuery payload empty".to_string());
310 }
311 let status = match payload[0] {
312 b'I' => TransactionStatus::Idle,
313 b'T' => TransactionStatus::InBlock,
314 b'E' => TransactionStatus::Failed,
315 _ => return Err("Unknown transaction status".to_string()),
316 };
317 Ok(BackendMessage::ReadyForQuery(status))
318 }
319
320 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
321 if payload.len() < 2 {
322 return Err("RowDescription payload too short".to_string());
323 }
324
325 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
326 if raw_count < 0 {
327 return Err(format!("RowDescription invalid field count: {}", raw_count));
328 }
329 let field_count = raw_count as usize;
330 let mut fields = Vec::with_capacity(field_count);
331 let mut pos = 2;
332
333 for _ in 0..field_count {
334 let name_end = payload[pos..]
336 .iter()
337 .position(|&b| b == 0)
338 .ok_or("Missing null terminator in field name")?;
339 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
340 pos += name_end + 1; if pos + 18 > payload.len() {
344 return Err("RowDescription field truncated".to_string());
345 }
346
347 let table_oid = u32::from_be_bytes([
348 payload[pos],
349 payload[pos + 1],
350 payload[pos + 2],
351 payload[pos + 3],
352 ]);
353 pos += 4;
354
355 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
356 pos += 2;
357
358 let type_oid = u32::from_be_bytes([
359 payload[pos],
360 payload[pos + 1],
361 payload[pos + 2],
362 payload[pos + 3],
363 ]);
364 pos += 4;
365
366 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
367 pos += 2;
368
369 let type_modifier = i32::from_be_bytes([
370 payload[pos],
371 payload[pos + 1],
372 payload[pos + 2],
373 payload[pos + 3],
374 ]);
375 pos += 4;
376
377 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
378 if !(0..=1).contains(&format) {
379 return Err(format!("RowDescription invalid format code: {}", format));
380 }
381 pos += 2;
382
383 fields.push(FieldDescription {
384 name,
385 table_oid,
386 column_attr,
387 type_oid,
388 type_size,
389 type_modifier,
390 format,
391 });
392 }
393
394 if pos != payload.len() {
395 return Err("RowDescription has trailing bytes".to_string());
396 }
397
398 Ok(BackendMessage::RowDescription(fields))
399 }
400
401 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
402 if payload.len() < 2 {
403 return Err("DataRow payload too short".to_string());
404 }
405
406 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
407 if raw_count < 0 {
408 return Err(format!("DataRow invalid column count: {}", raw_count));
409 }
410 let column_count = raw_count as usize;
411 if column_count > (payload.len() - 2) / 4 + 1 {
413 return Err(format!(
414 "DataRow claims {} columns but payload is only {} bytes",
415 column_count,
416 payload.len()
417 ));
418 }
419 let mut columns = Vec::with_capacity(column_count);
420 let mut pos = 2;
421
422 for _ in 0..column_count {
423 if pos + 4 > payload.len() {
424 return Err("DataRow truncated".to_string());
425 }
426
427 let len = i32::from_be_bytes([
428 payload[pos],
429 payload[pos + 1],
430 payload[pos + 2],
431 payload[pos + 3],
432 ]);
433 pos += 4;
434
435 if len == -1 {
436 columns.push(None);
438 } else {
439 if len < -1 {
440 return Err(format!("DataRow invalid column length: {}", len));
441 }
442 let len = len as usize;
443 if len > payload.len().saturating_sub(pos) {
444 return Err("DataRow column data truncated".to_string());
445 }
446 let data = payload[pos..pos + len].to_vec();
447 pos += len;
448 columns.push(Some(data));
449 }
450 }
451
452 if pos != payload.len() {
453 return Err("DataRow has trailing bytes".to_string());
454 }
455
456 Ok(BackendMessage::DataRow(columns))
457 }
458
459 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
460 if payload.last().copied() != Some(0) {
461 return Err("CommandComplete missing null terminator".to_string());
462 }
463 let tag_bytes = &payload[..payload.len() - 1];
464 if tag_bytes.contains(&0) {
465 return Err("CommandComplete contains interior null byte".to_string());
466 }
467 let tag = String::from_utf8_lossy(tag_bytes).to_string();
468 Ok(BackendMessage::CommandComplete(tag))
469 }
470
471 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
472 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
473 payload,
474 )?))
475 }
476
477 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
478 if payload.last().copied() != Some(0) {
479 return Err("ErrorResponse missing final terminator".to_string());
480 }
481 let mut fields = ErrorFields::default();
482 let mut i = 0;
483 while i < payload.len() && payload[i] != 0 {
484 let field_type = payload[i];
485 i += 1;
486 let end = payload[i..]
487 .iter()
488 .position(|&b| b == 0)
489 .map(|p| p + i)
490 .ok_or("ErrorResponse field missing null terminator")?;
491 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
492 i = end + 1;
493
494 match field_type {
495 b'S' => fields.severity = value,
496 b'C' => fields.code = value,
497 b'M' => fields.message = value,
498 b'D' => fields.detail = Some(value),
499 b'H' => fields.hint = Some(value),
500 _ => {}
501 }
502 }
503 if i + 1 != payload.len() {
504 return Err("ErrorResponse has trailing bytes after terminator".to_string());
505 }
506 Ok(fields)
507 }
508
509 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
510 if payload.len() < 2 {
511 return Err("ParameterDescription payload too short".to_string());
512 }
513 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
514 if raw_count < 0 {
515 return Err(format!("ParameterDescription invalid count: {}", raw_count));
516 }
517 let count = raw_count as usize;
518 let expected_len = 2 + count * 4;
519 if payload.len() < expected_len {
520 return Err(format!(
521 "ParameterDescription truncated: expected {} bytes, got {}",
522 expected_len,
523 payload.len()
524 ));
525 }
526 let mut oids = Vec::with_capacity(count);
527 let mut pos = 2;
528 for _ in 0..count {
529 oids.push(u32::from_be_bytes([
530 payload[pos],
531 payload[pos + 1],
532 payload[pos + 2],
533 payload[pos + 3],
534 ]));
535 pos += 4;
536 }
537 if pos != payload.len() {
538 return Err("ParameterDescription has trailing bytes".to_string());
539 }
540 Ok(BackendMessage::ParameterDescription(oids))
541 }
542
543 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
544 if payload.len() < 3 {
545 return Err("CopyInResponse payload too short".to_string());
546 }
547 let format = payload[0];
548 if format > 1 {
549 return Err(format!(
550 "CopyInResponse invalid overall format code: {}",
551 format
552 ));
553 }
554 let num_columns = if payload.len() >= 3 {
555 let raw = i16::from_be_bytes([payload[1], payload[2]]);
556 if raw < 0 {
557 return Err(format!(
558 "CopyInResponse invalid negative column count: {}",
559 raw
560 ));
561 }
562 raw as usize
563 } else {
564 0
565 };
566 let mut column_formats = Vec::with_capacity(num_columns);
567 let mut pos = 3usize;
568 for _ in 0..num_columns {
569 if pos + 2 > payload.len() {
570 return Err("CopyInResponse truncated column format list".to_string());
571 }
572 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
573 if !(0..=1).contains(&raw) {
574 return Err(format!("CopyInResponse invalid format code: {}", raw));
575 }
576 column_formats.push(raw as u8);
577 pos += 2;
578 }
579 if pos != payload.len() {
580 return Err("CopyInResponse has trailing bytes".to_string());
581 }
582 Ok(BackendMessage::CopyInResponse {
583 format,
584 column_formats,
585 })
586 }
587
588 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
589 if payload.len() < 3 {
590 return Err("CopyOutResponse payload too short".to_string());
591 }
592 let format = payload[0];
593 if format > 1 {
594 return Err(format!(
595 "CopyOutResponse invalid overall format code: {}",
596 format
597 ));
598 }
599 let num_columns = if payload.len() >= 3 {
600 let raw = i16::from_be_bytes([payload[1], payload[2]]);
601 if raw < 0 {
602 return Err(format!(
603 "CopyOutResponse invalid negative column count: {}",
604 raw
605 ));
606 }
607 raw as usize
608 } else {
609 0
610 };
611 let mut column_formats = Vec::with_capacity(num_columns);
612 let mut pos = 3usize;
613 for _ in 0..num_columns {
614 if pos + 2 > payload.len() {
615 return Err("CopyOutResponse truncated column format list".to_string());
616 }
617 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
618 if !(0..=1).contains(&raw) {
619 return Err(format!("CopyOutResponse invalid format code: {}", raw));
620 }
621 column_formats.push(raw as u8);
622 pos += 2;
623 }
624 if pos != payload.len() {
625 return Err("CopyOutResponse has trailing bytes".to_string());
626 }
627 Ok(BackendMessage::CopyOutResponse {
628 format,
629 column_formats,
630 })
631 }
632
633 fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
634 if payload.len() < 3 {
635 return Err("CopyBothResponse payload too short".to_string());
636 }
637 let format = payload[0];
638 if format > 1 {
639 return Err(format!(
640 "CopyBothResponse invalid overall format code: {}",
641 format
642 ));
643 }
644 let num_columns = if payload.len() >= 3 {
645 let raw = i16::from_be_bytes([payload[1], payload[2]]);
646 if raw < 0 {
647 return Err(format!(
648 "CopyBothResponse invalid negative column count: {}",
649 raw
650 ));
651 }
652 raw as usize
653 } else {
654 0
655 };
656 let mut column_formats = Vec::with_capacity(num_columns);
657 let mut pos = 3usize;
658 for _ in 0..num_columns {
659 if pos + 2 > payload.len() {
660 return Err("CopyBothResponse truncated column format list".to_string());
661 }
662 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
663 if !(0..=1).contains(&raw) {
664 return Err(format!("CopyBothResponse invalid format code: {}", raw));
665 }
666 column_formats.push(raw as u8);
667 pos += 2;
668 }
669 if pos != payload.len() {
670 return Err("CopyBothResponse has trailing bytes".to_string());
671 }
672 Ok(BackendMessage::CopyBothResponse {
673 format,
674 column_formats,
675 })
676 }
677
678 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
679 if payload.len() < 6 {
680 return Err("NotificationResponse too short".to_string());
682 }
683 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
684
685 let mut i = 4;
687 let remaining = payload.get(i..).unwrap_or(&[]);
688 let channel_end = remaining
689 .iter()
690 .position(|&b| b == 0)
691 .ok_or("NotificationResponse: missing channel null terminator")?;
692 let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
693 i += channel_end + 1;
694
695 let remaining = payload.get(i..).unwrap_or(&[]);
697 let payload_end = remaining
698 .iter()
699 .position(|&b| b == 0)
700 .ok_or("NotificationResponse: missing payload null terminator")?;
701 let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
702 if i + payload_end + 1 != payload.len() {
703 return Err("NotificationResponse has trailing bytes".to_string());
704 }
705
706 Ok(BackendMessage::NotificationResponse {
707 process_id,
708 channel,
709 payload: notification_payload,
710 })
711 }
712}