1use super::types::*;
4
/// Largest value accepted in a backend message's 4-byte length field: 64 MiB.
///
/// `BackendMessage::decode` refuses any frame whose declared length exceeds
/// this bound instead of waiting for that much data to arrive.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
11impl BackendMessage {
12 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
14 if buf.len() < 5 {
15 return Err("Buffer too short".to_string());
16 }
17
18 let msg_type = buf[0];
19 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
20
21 if len < 4 {
24 return Err(format!("Invalid message length: {} (minimum is 4)", len));
25 }
26 if len > MAX_BACKEND_FRAME_LEN {
27 return Err(format!(
28 "Message too large: {} bytes (max {})",
29 len, MAX_BACKEND_FRAME_LEN
30 ));
31 }
32
33 let frame_len = len
34 .checked_add(1)
35 .ok_or_else(|| "Message length overflow".to_string())?;
36
37 if buf.len() < frame_len {
38 return Err("Incomplete message".to_string());
39 }
40
41 let payload = &buf[5..frame_len];
42
43 let message = match msg_type {
44 b'R' => Self::decode_auth(payload)?,
45 b'S' => Self::decode_parameter_status(payload)?,
46 b'K' => Self::decode_backend_key(payload)?,
47 b'Z' => Self::decode_ready_for_query(payload)?,
48 b'T' => Self::decode_row_description(payload)?,
49 b'D' => Self::decode_data_row(payload)?,
50 b'C' => Self::decode_command_complete(payload)?,
51 b'E' => Self::decode_error_response(payload)?,
52 b'1' => {
53 if !payload.is_empty() {
54 return Err("ParseComplete must have empty payload".to_string());
55 }
56 BackendMessage::ParseComplete
57 }
58 b'2' => {
59 if !payload.is_empty() {
60 return Err("BindComplete must have empty payload".to_string());
61 }
62 BackendMessage::BindComplete
63 }
64 b'3' => {
65 if !payload.is_empty() {
66 return Err("CloseComplete must have empty payload".to_string());
67 }
68 BackendMessage::CloseComplete
69 }
70 b'n' => {
71 if !payload.is_empty() {
72 return Err("NoData must have empty payload".to_string());
73 }
74 BackendMessage::NoData
75 }
76 b's' => {
77 if !payload.is_empty() {
78 return Err("PortalSuspended must have empty payload".to_string());
79 }
80 BackendMessage::PortalSuspended
81 }
82 b't' => Self::decode_parameter_description(payload)?,
83 b'G' => Self::decode_copy_in_response(payload)?,
84 b'H' => Self::decode_copy_out_response(payload)?,
85 b'W' => Self::decode_copy_both_response(payload)?,
86 b'd' => BackendMessage::CopyData(payload.to_vec()),
87 b'c' => {
88 if !payload.is_empty() {
89 return Err("CopyDone must have empty payload".to_string());
90 }
91 BackendMessage::CopyDone
92 }
93 b'A' => Self::decode_notification_response(payload)?,
94 b'I' => {
95 if !payload.is_empty() {
96 return Err("EmptyQueryResponse must have empty payload".to_string());
97 }
98 BackendMessage::EmptyQueryResponse
99 }
100 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
101 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
102 };
103
104 Ok((message, frame_len))
105 }
106
107 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
108 if payload.len() < 4 {
109 return Err("Auth payload too short".to_string());
110 }
111 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
112 match auth_type {
113 0 => {
114 if payload.len() != 4 {
115 return Err(format!(
116 "AuthenticationOk invalid payload length: {}",
117 payload.len()
118 ));
119 }
120 Ok(BackendMessage::AuthenticationOk)
121 }
122 2 => {
123 if payload.len() != 4 {
124 return Err(format!(
125 "AuthenticationKerberosV5 invalid payload length: {}",
126 payload.len()
127 ));
128 }
129 Ok(BackendMessage::AuthenticationKerberosV5)
130 }
131 3 => {
132 if payload.len() != 4 {
133 return Err(format!(
134 "AuthenticationCleartextPassword invalid payload length: {}",
135 payload.len()
136 ));
137 }
138 Ok(BackendMessage::AuthenticationCleartextPassword)
139 }
140 5 => {
141 if payload.len() != 8 {
142 return Err("MD5 auth payload too short (need salt)".to_string());
143 }
144 let mut salt = [0u8; 4];
145 salt.copy_from_slice(&payload[4..8]);
146 Ok(BackendMessage::AuthenticationMD5Password(salt))
147 }
148 6 => {
149 if payload.len() != 4 {
150 return Err(format!(
151 "AuthenticationSCMCredential invalid payload length: {}",
152 payload.len()
153 ));
154 }
155 Ok(BackendMessage::AuthenticationSCMCredential)
156 }
157 7 => {
158 if payload.len() != 4 {
159 return Err(format!(
160 "AuthenticationGSS invalid payload length: {}",
161 payload.len()
162 ));
163 }
164 Ok(BackendMessage::AuthenticationGSS)
165 }
166 8 => Ok(BackendMessage::AuthenticationGSSContinue(
167 payload[4..].to_vec(),
168 )),
169 9 => {
170 if payload.len() != 4 {
171 return Err(format!(
172 "AuthenticationSSPI invalid payload length: {}",
173 payload.len()
174 ));
175 }
176 Ok(BackendMessage::AuthenticationSSPI)
177 }
178 10 => {
179 let mut mechanisms = Vec::new();
181 let mut pos = 4;
182 while pos < payload.len() {
183 if payload[pos] == 0 {
184 break; }
186 let end = payload[pos..]
187 .iter()
188 .position(|&b| b == 0)
189 .map(|p| pos + p)
190 .ok_or("SASL mechanism list missing null terminator")?;
191 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
192 pos = end + 1;
193 }
194 if pos >= payload.len() {
195 return Err("SASL mechanism list missing final terminator".to_string());
196 }
197 if pos + 1 != payload.len() {
198 return Err("SASL mechanism list has trailing bytes".to_string());
199 }
200 if mechanisms.is_empty() {
201 return Err("SASL mechanism list is empty".to_string());
202 }
203 Ok(BackendMessage::AuthenticationSASL(mechanisms))
204 }
205 11 => {
206 Ok(BackendMessage::AuthenticationSASLContinue(
208 payload[4..].to_vec(),
209 ))
210 }
211 12 => {
212 Ok(BackendMessage::AuthenticationSASLFinal(
214 payload[4..].to_vec(),
215 ))
216 }
217 _ => Err(format!("Unknown auth type: {}", auth_type)),
218 }
219 }
220
221 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
222 let name_end = payload
223 .iter()
224 .position(|&b| b == 0)
225 .ok_or("ParameterStatus missing name terminator")?;
226 let value_start = name_end + 1;
227 if value_start > payload.len() {
228 return Err("ParameterStatus missing value".to_string());
229 }
230 let value_end_rel = payload[value_start..]
231 .iter()
232 .position(|&b| b == 0)
233 .ok_or("ParameterStatus missing value terminator")?;
234 let value_end = value_start + value_end_rel;
235 if value_end + 1 != payload.len() {
236 return Err("ParameterStatus has trailing bytes".to_string());
237 }
238 Ok(BackendMessage::ParameterStatus {
239 name: String::from_utf8_lossy(&payload[..name_end]).to_string(),
240 value: String::from_utf8_lossy(&payload[value_start..value_end]).to_string(),
241 })
242 }
243
244 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
245 if payload.len() != 8 {
246 return Err("BackendKeyData payload too short".to_string());
247 }
248 Ok(BackendMessage::BackendKeyData {
249 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
250 secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
251 })
252 }
253
254 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
255 if payload.len() != 1 {
256 return Err("ReadyForQuery payload empty".to_string());
257 }
258 let status = match payload[0] {
259 b'I' => TransactionStatus::Idle,
260 b'T' => TransactionStatus::InBlock,
261 b'E' => TransactionStatus::Failed,
262 _ => return Err("Unknown transaction status".to_string()),
263 };
264 Ok(BackendMessage::ReadyForQuery(status))
265 }
266
267 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
268 if payload.len() < 2 {
269 return Err("RowDescription payload too short".to_string());
270 }
271
272 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
273 if raw_count < 0 {
274 return Err(format!("RowDescription invalid field count: {}", raw_count));
275 }
276 let field_count = raw_count as usize;
277 let mut fields = Vec::with_capacity(field_count);
278 let mut pos = 2;
279
280 for _ in 0..field_count {
281 let name_end = payload[pos..]
283 .iter()
284 .position(|&b| b == 0)
285 .ok_or("Missing null terminator in field name")?;
286 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
287 pos += name_end + 1; if pos + 18 > payload.len() {
291 return Err("RowDescription field truncated".to_string());
292 }
293
294 let table_oid = u32::from_be_bytes([
295 payload[pos],
296 payload[pos + 1],
297 payload[pos + 2],
298 payload[pos + 3],
299 ]);
300 pos += 4;
301
302 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
303 pos += 2;
304
305 let type_oid = u32::from_be_bytes([
306 payload[pos],
307 payload[pos + 1],
308 payload[pos + 2],
309 payload[pos + 3],
310 ]);
311 pos += 4;
312
313 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
314 pos += 2;
315
316 let type_modifier = i32::from_be_bytes([
317 payload[pos],
318 payload[pos + 1],
319 payload[pos + 2],
320 payload[pos + 3],
321 ]);
322 pos += 4;
323
324 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
325 if !(0..=1).contains(&format) {
326 return Err(format!("RowDescription invalid format code: {}", format));
327 }
328 pos += 2;
329
330 fields.push(FieldDescription {
331 name,
332 table_oid,
333 column_attr,
334 type_oid,
335 type_size,
336 type_modifier,
337 format,
338 });
339 }
340
341 if pos != payload.len() {
342 return Err("RowDescription has trailing bytes".to_string());
343 }
344
345 Ok(BackendMessage::RowDescription(fields))
346 }
347
348 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
349 if payload.len() < 2 {
350 return Err("DataRow payload too short".to_string());
351 }
352
353 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
354 if raw_count < 0 {
355 return Err(format!("DataRow invalid column count: {}", raw_count));
356 }
357 let column_count = raw_count as usize;
358 if column_count > (payload.len() - 2) / 4 + 1 {
360 return Err(format!(
361 "DataRow claims {} columns but payload is only {} bytes",
362 column_count,
363 payload.len()
364 ));
365 }
366 let mut columns = Vec::with_capacity(column_count);
367 let mut pos = 2;
368
369 for _ in 0..column_count {
370 if pos + 4 > payload.len() {
371 return Err("DataRow truncated".to_string());
372 }
373
374 let len = i32::from_be_bytes([
375 payload[pos],
376 payload[pos + 1],
377 payload[pos + 2],
378 payload[pos + 3],
379 ]);
380 pos += 4;
381
382 if len == -1 {
383 columns.push(None);
385 } else {
386 if len < -1 {
387 return Err(format!("DataRow invalid column length: {}", len));
388 }
389 let len = len as usize;
390 if len > payload.len().saturating_sub(pos) {
391 return Err("DataRow column data truncated".to_string());
392 }
393 let data = payload[pos..pos + len].to_vec();
394 pos += len;
395 columns.push(Some(data));
396 }
397 }
398
399 if pos != payload.len() {
400 return Err("DataRow has trailing bytes".to_string());
401 }
402
403 Ok(BackendMessage::DataRow(columns))
404 }
405
406 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
407 if payload.last().copied() != Some(0) {
408 return Err("CommandComplete missing null terminator".to_string());
409 }
410 let tag_bytes = &payload[..payload.len() - 1];
411 if tag_bytes.contains(&0) {
412 return Err("CommandComplete contains interior null byte".to_string());
413 }
414 let tag = String::from_utf8_lossy(tag_bytes).to_string();
415 Ok(BackendMessage::CommandComplete(tag))
416 }
417
418 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
419 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
420 payload,
421 )?))
422 }
423
424 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
425 if payload.last().copied() != Some(0) {
426 return Err("ErrorResponse missing final terminator".to_string());
427 }
428 let mut fields = ErrorFields::default();
429 let mut i = 0;
430 while i < payload.len() && payload[i] != 0 {
431 let field_type = payload[i];
432 i += 1;
433 let end = payload[i..]
434 .iter()
435 .position(|&b| b == 0)
436 .map(|p| p + i)
437 .ok_or("ErrorResponse field missing null terminator")?;
438 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
439 i = end + 1;
440
441 match field_type {
442 b'S' => fields.severity = value,
443 b'C' => fields.code = value,
444 b'M' => fields.message = value,
445 b'D' => fields.detail = Some(value),
446 b'H' => fields.hint = Some(value),
447 _ => {}
448 }
449 }
450 if i + 1 != payload.len() {
451 return Err("ErrorResponse has trailing bytes after terminator".to_string());
452 }
453 Ok(fields)
454 }
455
456 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
457 if payload.len() < 2 {
458 return Err("ParameterDescription payload too short".to_string());
459 }
460 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
461 if raw_count < 0 {
462 return Err(format!("ParameterDescription invalid count: {}", raw_count));
463 }
464 let count = raw_count as usize;
465 let expected_len = 2 + count * 4;
466 if payload.len() < expected_len {
467 return Err(format!(
468 "ParameterDescription truncated: expected {} bytes, got {}",
469 expected_len,
470 payload.len()
471 ));
472 }
473 let mut oids = Vec::with_capacity(count);
474 let mut pos = 2;
475 for _ in 0..count {
476 oids.push(u32::from_be_bytes([
477 payload[pos],
478 payload[pos + 1],
479 payload[pos + 2],
480 payload[pos + 3],
481 ]));
482 pos += 4;
483 }
484 if pos != payload.len() {
485 return Err("ParameterDescription has trailing bytes".to_string());
486 }
487 Ok(BackendMessage::ParameterDescription(oids))
488 }
489
490 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
491 if payload.len() < 3 {
492 return Err("CopyInResponse payload too short".to_string());
493 }
494 let format = payload[0];
495 if format > 1 {
496 return Err(format!(
497 "CopyInResponse invalid overall format code: {}",
498 format
499 ));
500 }
501 let num_columns = if payload.len() >= 3 {
502 let raw = i16::from_be_bytes([payload[1], payload[2]]);
503 if raw < 0 {
504 return Err(format!(
505 "CopyInResponse invalid negative column count: {}",
506 raw
507 ));
508 }
509 raw as usize
510 } else {
511 0
512 };
513 let mut column_formats = Vec::with_capacity(num_columns);
514 let mut pos = 3usize;
515 for _ in 0..num_columns {
516 if pos + 2 > payload.len() {
517 return Err("CopyInResponse truncated column format list".to_string());
518 }
519 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
520 if !(0..=1).contains(&raw) {
521 return Err(format!("CopyInResponse invalid format code: {}", raw));
522 }
523 column_formats.push(raw as u8);
524 pos += 2;
525 }
526 if pos != payload.len() {
527 return Err("CopyInResponse has trailing bytes".to_string());
528 }
529 Ok(BackendMessage::CopyInResponse {
530 format,
531 column_formats,
532 })
533 }
534
535 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
536 if payload.len() < 3 {
537 return Err("CopyOutResponse payload too short".to_string());
538 }
539 let format = payload[0];
540 if format > 1 {
541 return Err(format!(
542 "CopyOutResponse invalid overall format code: {}",
543 format
544 ));
545 }
546 let num_columns = if payload.len() >= 3 {
547 let raw = i16::from_be_bytes([payload[1], payload[2]]);
548 if raw < 0 {
549 return Err(format!(
550 "CopyOutResponse invalid negative column count: {}",
551 raw
552 ));
553 }
554 raw as usize
555 } else {
556 0
557 };
558 let mut column_formats = Vec::with_capacity(num_columns);
559 let mut pos = 3usize;
560 for _ in 0..num_columns {
561 if pos + 2 > payload.len() {
562 return Err("CopyOutResponse truncated column format list".to_string());
563 }
564 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
565 if !(0..=1).contains(&raw) {
566 return Err(format!("CopyOutResponse invalid format code: {}", raw));
567 }
568 column_formats.push(raw as u8);
569 pos += 2;
570 }
571 if pos != payload.len() {
572 return Err("CopyOutResponse has trailing bytes".to_string());
573 }
574 Ok(BackendMessage::CopyOutResponse {
575 format,
576 column_formats,
577 })
578 }
579
580 fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
581 if payload.len() < 3 {
582 return Err("CopyBothResponse payload too short".to_string());
583 }
584 let format = payload[0];
585 if format > 1 {
586 return Err(format!(
587 "CopyBothResponse invalid overall format code: {}",
588 format
589 ));
590 }
591 let num_columns = if payload.len() >= 3 {
592 let raw = i16::from_be_bytes([payload[1], payload[2]]);
593 if raw < 0 {
594 return Err(format!(
595 "CopyBothResponse invalid negative column count: {}",
596 raw
597 ));
598 }
599 raw as usize
600 } else {
601 0
602 };
603 let mut column_formats = Vec::with_capacity(num_columns);
604 let mut pos = 3usize;
605 for _ in 0..num_columns {
606 if pos + 2 > payload.len() {
607 return Err("CopyBothResponse truncated column format list".to_string());
608 }
609 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
610 if !(0..=1).contains(&raw) {
611 return Err(format!("CopyBothResponse invalid format code: {}", raw));
612 }
613 column_formats.push(raw as u8);
614 pos += 2;
615 }
616 if pos != payload.len() {
617 return Err("CopyBothResponse has trailing bytes".to_string());
618 }
619 Ok(BackendMessage::CopyBothResponse {
620 format,
621 column_formats,
622 })
623 }
624
625 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
626 if payload.len() < 6 {
627 return Err("NotificationResponse too short".to_string());
629 }
630 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
631
632 let mut i = 4;
634 let remaining = payload.get(i..).unwrap_or(&[]);
635 let channel_end = remaining
636 .iter()
637 .position(|&b| b == 0)
638 .ok_or("NotificationResponse: missing channel null terminator")?;
639 let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
640 i += channel_end + 1;
641
642 let remaining = payload.get(i..).unwrap_or(&[]);
644 let payload_end = remaining
645 .iter()
646 .position(|&b| b == 0)
647 .ok_or("NotificationResponse: missing payload null terminator")?;
648 let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
649 if i + payload_end + 1 != payload.len() {
650 return Err("NotificationResponse has trailing bytes".to_string());
651 }
652
653 Ok(BackendMessage::NotificationResponse {
654 process_id,
655 channel,
656 payload: notification_payload,
657 })
658 }
659}