1use super::types::*;
4
/// Largest value of a backend frame's length field (in bytes) that
/// `BackendMessage::decode` accepts; any larger declared length is rejected
/// with a "Message too large" error. Equal to 64 MiB.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
11impl BackendMessage {
12 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
14 if buf.len() < 5 {
15 return Err("Buffer too short".to_string());
16 }
17
18 let msg_type = buf[0];
19 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
20
21 if len < 4 {
24 return Err(format!("Invalid message length: {} (minimum is 4)", len));
25 }
26 if len > MAX_BACKEND_FRAME_LEN {
27 return Err(format!(
28 "Message too large: {} bytes (max {})",
29 len, MAX_BACKEND_FRAME_LEN
30 ));
31 }
32
33 let frame_len = len
34 .checked_add(1)
35 .ok_or_else(|| "Message length overflow".to_string())?;
36
37 if buf.len() < frame_len {
38 return Err("Incomplete message".to_string());
39 }
40
41 let payload = &buf[5..frame_len];
42
43 let message = match msg_type {
44 b'R' => Self::decode_auth(payload)?,
45 b'S' => Self::decode_parameter_status(payload)?,
46 b'K' => Self::decode_backend_key(payload)?,
47 b'Z' => Self::decode_ready_for_query(payload)?,
48 b'T' => Self::decode_row_description(payload)?,
49 b'D' => Self::decode_data_row(payload)?,
50 b'C' => Self::decode_command_complete(payload)?,
51 b'E' => Self::decode_error_response(payload)?,
52 b'1' => {
53 if !payload.is_empty() {
54 return Err("ParseComplete must have empty payload".to_string());
55 }
56 BackendMessage::ParseComplete
57 }
58 b'2' => {
59 if !payload.is_empty() {
60 return Err("BindComplete must have empty payload".to_string());
61 }
62 BackendMessage::BindComplete
63 }
64 b'3' => {
65 if !payload.is_empty() {
66 return Err("CloseComplete must have empty payload".to_string());
67 }
68 BackendMessage::CloseComplete
69 }
70 b'n' => {
71 if !payload.is_empty() {
72 return Err("NoData must have empty payload".to_string());
73 }
74 BackendMessage::NoData
75 }
76 b's' => {
77 if !payload.is_empty() {
78 return Err("PortalSuspended must have empty payload".to_string());
79 }
80 BackendMessage::PortalSuspended
81 }
82 b't' => Self::decode_parameter_description(payload)?,
83 b'G' => Self::decode_copy_in_response(payload)?,
84 b'H' => Self::decode_copy_out_response(payload)?,
85 b'W' => Self::decode_copy_both_response(payload)?,
86 b'd' => BackendMessage::CopyData(payload.to_vec()),
87 b'c' => {
88 if !payload.is_empty() {
89 return Err("CopyDone must have empty payload".to_string());
90 }
91 BackendMessage::CopyDone
92 }
93 b'A' => Self::decode_notification_response(payload)?,
94 b'I' => {
95 if !payload.is_empty() {
96 return Err("EmptyQueryResponse must have empty payload".to_string());
97 }
98 BackendMessage::EmptyQueryResponse
99 }
100 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
101 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
102 };
103
104 Ok((message, frame_len))
105 }
106
107 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
108 if payload.len() < 4 {
109 return Err("Auth payload too short".to_string());
110 }
111 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
112 match auth_type {
113 0 => {
114 if payload.len() != 4 {
115 return Err(format!(
116 "AuthenticationOk invalid payload length: {}",
117 payload.len()
118 ));
119 }
120 Ok(BackendMessage::AuthenticationOk)
121 }
122 2 => {
123 if payload.len() != 4 {
124 return Err(format!(
125 "AuthenticationKerberosV5 invalid payload length: {}",
126 payload.len()
127 ));
128 }
129 Ok(BackendMessage::AuthenticationKerberosV5)
130 }
131 3 => {
132 if payload.len() != 4 {
133 return Err(format!(
134 "AuthenticationCleartextPassword invalid payload length: {}",
135 payload.len()
136 ));
137 }
138 Ok(BackendMessage::AuthenticationCleartextPassword)
139 }
140 5 => {
141 if payload.len() != 8 {
142 return Err("MD5 auth payload too short (need salt)".to_string());
143 }
144 let mut salt = [0u8; 4];
145 salt.copy_from_slice(&payload[4..8]);
146 Ok(BackendMessage::AuthenticationMD5Password(salt))
147 }
148 7 => {
149 if payload.len() != 4 {
150 return Err(format!(
151 "AuthenticationGSS invalid payload length: {}",
152 payload.len()
153 ));
154 }
155 Ok(BackendMessage::AuthenticationGSS)
156 }
157 8 => Ok(BackendMessage::AuthenticationGSSContinue(
158 payload[4..].to_vec(),
159 )),
160 9 => {
161 if payload.len() != 4 {
162 return Err(format!(
163 "AuthenticationSSPI invalid payload length: {}",
164 payload.len()
165 ));
166 }
167 Ok(BackendMessage::AuthenticationSSPI)
168 }
169 10 => {
170 let mut mechanisms = Vec::new();
172 let mut pos = 4;
173 while pos < payload.len() {
174 if payload[pos] == 0 {
175 break; }
177 let end = payload[pos..]
178 .iter()
179 .position(|&b| b == 0)
180 .map(|p| pos + p)
181 .ok_or("SASL mechanism list missing null terminator")?;
182 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
183 pos = end + 1;
184 }
185 if pos >= payload.len() {
186 return Err("SASL mechanism list missing final terminator".to_string());
187 }
188 if pos + 1 != payload.len() {
189 return Err("SASL mechanism list has trailing bytes".to_string());
190 }
191 if mechanisms.is_empty() {
192 return Err("SASL mechanism list is empty".to_string());
193 }
194 Ok(BackendMessage::AuthenticationSASL(mechanisms))
195 }
196 11 => {
197 Ok(BackendMessage::AuthenticationSASLContinue(
199 payload[4..].to_vec(),
200 ))
201 }
202 12 => {
203 Ok(BackendMessage::AuthenticationSASLFinal(
205 payload[4..].to_vec(),
206 ))
207 }
208 _ => Err(format!("Unknown auth type: {}", auth_type)),
209 }
210 }
211
212 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
213 let name_end = payload
214 .iter()
215 .position(|&b| b == 0)
216 .ok_or("ParameterStatus missing name terminator")?;
217 let value_start = name_end + 1;
218 if value_start > payload.len() {
219 return Err("ParameterStatus missing value".to_string());
220 }
221 let value_end_rel = payload[value_start..]
222 .iter()
223 .position(|&b| b == 0)
224 .ok_or("ParameterStatus missing value terminator")?;
225 let value_end = value_start + value_end_rel;
226 if value_end + 1 != payload.len() {
227 return Err("ParameterStatus has trailing bytes".to_string());
228 }
229 Ok(BackendMessage::ParameterStatus {
230 name: String::from_utf8_lossy(&payload[..name_end]).to_string(),
231 value: String::from_utf8_lossy(&payload[value_start..value_end]).to_string(),
232 })
233 }
234
235 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
236 if payload.len() != 8 {
237 return Err("BackendKeyData payload too short".to_string());
238 }
239 Ok(BackendMessage::BackendKeyData {
240 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
241 secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
242 })
243 }
244
245 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
246 if payload.len() != 1 {
247 return Err("ReadyForQuery payload empty".to_string());
248 }
249 let status = match payload[0] {
250 b'I' => TransactionStatus::Idle,
251 b'T' => TransactionStatus::InBlock,
252 b'E' => TransactionStatus::Failed,
253 _ => return Err("Unknown transaction status".to_string()),
254 };
255 Ok(BackendMessage::ReadyForQuery(status))
256 }
257
258 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
259 if payload.len() < 2 {
260 return Err("RowDescription payload too short".to_string());
261 }
262
263 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
264 if raw_count < 0 {
265 return Err(format!("RowDescription invalid field count: {}", raw_count));
266 }
267 let field_count = raw_count as usize;
268 let mut fields = Vec::with_capacity(field_count);
269 let mut pos = 2;
270
271 for _ in 0..field_count {
272 let name_end = payload[pos..]
274 .iter()
275 .position(|&b| b == 0)
276 .ok_or("Missing null terminator in field name")?;
277 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
278 pos += name_end + 1; if pos + 18 > payload.len() {
282 return Err("RowDescription field truncated".to_string());
283 }
284
285 let table_oid = u32::from_be_bytes([
286 payload[pos],
287 payload[pos + 1],
288 payload[pos + 2],
289 payload[pos + 3],
290 ]);
291 pos += 4;
292
293 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
294 pos += 2;
295
296 let type_oid = u32::from_be_bytes([
297 payload[pos],
298 payload[pos + 1],
299 payload[pos + 2],
300 payload[pos + 3],
301 ]);
302 pos += 4;
303
304 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
305 pos += 2;
306
307 let type_modifier = i32::from_be_bytes([
308 payload[pos],
309 payload[pos + 1],
310 payload[pos + 2],
311 payload[pos + 3],
312 ]);
313 pos += 4;
314
315 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
316 if !(0..=1).contains(&format) {
317 return Err(format!("RowDescription invalid format code: {}", format));
318 }
319 pos += 2;
320
321 fields.push(FieldDescription {
322 name,
323 table_oid,
324 column_attr,
325 type_oid,
326 type_size,
327 type_modifier,
328 format,
329 });
330 }
331
332 if pos != payload.len() {
333 return Err("RowDescription has trailing bytes".to_string());
334 }
335
336 Ok(BackendMessage::RowDescription(fields))
337 }
338
339 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
340 if payload.len() < 2 {
341 return Err("DataRow payload too short".to_string());
342 }
343
344 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
345 if raw_count < 0 {
346 return Err(format!("DataRow invalid column count: {}", raw_count));
347 }
348 let column_count = raw_count as usize;
349 if column_count > (payload.len() - 2) / 4 + 1 {
351 return Err(format!(
352 "DataRow claims {} columns but payload is only {} bytes",
353 column_count,
354 payload.len()
355 ));
356 }
357 let mut columns = Vec::with_capacity(column_count);
358 let mut pos = 2;
359
360 for _ in 0..column_count {
361 if pos + 4 > payload.len() {
362 return Err("DataRow truncated".to_string());
363 }
364
365 let len = i32::from_be_bytes([
366 payload[pos],
367 payload[pos + 1],
368 payload[pos + 2],
369 payload[pos + 3],
370 ]);
371 pos += 4;
372
373 if len == -1 {
374 columns.push(None);
376 } else {
377 if len < -1 {
378 return Err(format!("DataRow invalid column length: {}", len));
379 }
380 let len = len as usize;
381 if len > payload.len().saturating_sub(pos) {
382 return Err("DataRow column data truncated".to_string());
383 }
384 let data = payload[pos..pos + len].to_vec();
385 pos += len;
386 columns.push(Some(data));
387 }
388 }
389
390 if pos != payload.len() {
391 return Err("DataRow has trailing bytes".to_string());
392 }
393
394 Ok(BackendMessage::DataRow(columns))
395 }
396
397 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
398 if payload.last().copied() != Some(0) {
399 return Err("CommandComplete missing null terminator".to_string());
400 }
401 let tag_bytes = &payload[..payload.len() - 1];
402 if tag_bytes.contains(&0) {
403 return Err("CommandComplete contains interior null byte".to_string());
404 }
405 let tag = String::from_utf8_lossy(tag_bytes).to_string();
406 Ok(BackendMessage::CommandComplete(tag))
407 }
408
409 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
410 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
411 payload,
412 )?))
413 }
414
415 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
416 if payload.last().copied() != Some(0) {
417 return Err("ErrorResponse missing final terminator".to_string());
418 }
419 let mut fields = ErrorFields::default();
420 let mut i = 0;
421 while i < payload.len() && payload[i] != 0 {
422 let field_type = payload[i];
423 i += 1;
424 let end = payload[i..]
425 .iter()
426 .position(|&b| b == 0)
427 .map(|p| p + i)
428 .ok_or("ErrorResponse field missing null terminator")?;
429 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
430 i = end + 1;
431
432 match field_type {
433 b'S' => fields.severity = value,
434 b'C' => fields.code = value,
435 b'M' => fields.message = value,
436 b'D' => fields.detail = Some(value),
437 b'H' => fields.hint = Some(value),
438 _ => {}
439 }
440 }
441 if i + 1 != payload.len() {
442 return Err("ErrorResponse has trailing bytes after terminator".to_string());
443 }
444 Ok(fields)
445 }
446
447 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
448 if payload.len() < 2 {
449 return Err("ParameterDescription payload too short".to_string());
450 }
451 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
452 if raw_count < 0 {
453 return Err(format!("ParameterDescription invalid count: {}", raw_count));
454 }
455 let count = raw_count as usize;
456 let expected_len = 2 + count * 4;
457 if payload.len() < expected_len {
458 return Err(format!(
459 "ParameterDescription truncated: expected {} bytes, got {}",
460 expected_len,
461 payload.len()
462 ));
463 }
464 let mut oids = Vec::with_capacity(count);
465 let mut pos = 2;
466 for _ in 0..count {
467 oids.push(u32::from_be_bytes([
468 payload[pos],
469 payload[pos + 1],
470 payload[pos + 2],
471 payload[pos + 3],
472 ]));
473 pos += 4;
474 }
475 if pos != payload.len() {
476 return Err("ParameterDescription has trailing bytes".to_string());
477 }
478 Ok(BackendMessage::ParameterDescription(oids))
479 }
480
481 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
482 if payload.len() < 3 {
483 return Err("CopyInResponse payload too short".to_string());
484 }
485 let format = payload[0];
486 if format > 1 {
487 return Err(format!(
488 "CopyInResponse invalid overall format code: {}",
489 format
490 ));
491 }
492 let num_columns = if payload.len() >= 3 {
493 let raw = i16::from_be_bytes([payload[1], payload[2]]);
494 if raw < 0 {
495 return Err(format!(
496 "CopyInResponse invalid negative column count: {}",
497 raw
498 ));
499 }
500 raw as usize
501 } else {
502 0
503 };
504 let mut column_formats = Vec::with_capacity(num_columns);
505 let mut pos = 3usize;
506 for _ in 0..num_columns {
507 if pos + 2 > payload.len() {
508 return Err("CopyInResponse truncated column format list".to_string());
509 }
510 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
511 if !(0..=1).contains(&raw) {
512 return Err(format!("CopyInResponse invalid format code: {}", raw));
513 }
514 column_formats.push(raw as u8);
515 pos += 2;
516 }
517 if pos != payload.len() {
518 return Err("CopyInResponse has trailing bytes".to_string());
519 }
520 Ok(BackendMessage::CopyInResponse {
521 format,
522 column_formats,
523 })
524 }
525
526 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
527 if payload.len() < 3 {
528 return Err("CopyOutResponse payload too short".to_string());
529 }
530 let format = payload[0];
531 if format > 1 {
532 return Err(format!(
533 "CopyOutResponse invalid overall format code: {}",
534 format
535 ));
536 }
537 let num_columns = if payload.len() >= 3 {
538 let raw = i16::from_be_bytes([payload[1], payload[2]]);
539 if raw < 0 {
540 return Err(format!(
541 "CopyOutResponse invalid negative column count: {}",
542 raw
543 ));
544 }
545 raw as usize
546 } else {
547 0
548 };
549 let mut column_formats = Vec::with_capacity(num_columns);
550 let mut pos = 3usize;
551 for _ in 0..num_columns {
552 if pos + 2 > payload.len() {
553 return Err("CopyOutResponse truncated column format list".to_string());
554 }
555 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
556 if !(0..=1).contains(&raw) {
557 return Err(format!("CopyOutResponse invalid format code: {}", raw));
558 }
559 column_formats.push(raw as u8);
560 pos += 2;
561 }
562 if pos != payload.len() {
563 return Err("CopyOutResponse has trailing bytes".to_string());
564 }
565 Ok(BackendMessage::CopyOutResponse {
566 format,
567 column_formats,
568 })
569 }
570
571 fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
572 if payload.len() < 3 {
573 return Err("CopyBothResponse payload too short".to_string());
574 }
575 let format = payload[0];
576 if format > 1 {
577 return Err(format!(
578 "CopyBothResponse invalid overall format code: {}",
579 format
580 ));
581 }
582 let num_columns = if payload.len() >= 3 {
583 let raw = i16::from_be_bytes([payload[1], payload[2]]);
584 if raw < 0 {
585 return Err(format!(
586 "CopyBothResponse invalid negative column count: {}",
587 raw
588 ));
589 }
590 raw as usize
591 } else {
592 0
593 };
594 let mut column_formats = Vec::with_capacity(num_columns);
595 let mut pos = 3usize;
596 for _ in 0..num_columns {
597 if pos + 2 > payload.len() {
598 return Err("CopyBothResponse truncated column format list".to_string());
599 }
600 let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
601 if !(0..=1).contains(&raw) {
602 return Err(format!("CopyBothResponse invalid format code: {}", raw));
603 }
604 column_formats.push(raw as u8);
605 pos += 2;
606 }
607 if pos != payload.len() {
608 return Err("CopyBothResponse has trailing bytes".to_string());
609 }
610 Ok(BackendMessage::CopyBothResponse {
611 format,
612 column_formats,
613 })
614 }
615
616 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
617 if payload.len() < 6 {
618 return Err("NotificationResponse too short".to_string());
620 }
621 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
622
623 let mut i = 4;
625 let remaining = payload.get(i..).unwrap_or(&[]);
626 let channel_end = remaining
627 .iter()
628 .position(|&b| b == 0)
629 .ok_or("NotificationResponse: missing channel null terminator")?;
630 let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
631 i += channel_end + 1;
632
633 let remaining = payload.get(i..).unwrap_or(&[]);
635 let payload_end = remaining
636 .iter()
637 .position(|&b| b == 0)
638 .ok_or("NotificationResponse: missing payload null terminator")?;
639 let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
640 if i + payload_end + 1 != payload.len() {
641 return Err("NotificationResponse has trailing bytes".to_string());
642 }
643
644 Ok(BackendMessage::NotificationResponse {
645 process_id,
646 channel,
647 payload: notification_payload,
648 })
649 }
650}