/// A message the client sends to a PostgreSQL server.
///
/// Each variant is one frontend message of wire protocol 3.0; `FrontendMessage::encode`
/// turns a value into the bytes to write on the connection. `PartialEq`/`Eq` are derived
/// so messages can be compared directly (e.g. in tests or when de-duplicating requests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrontendMessage {
    /// Session start-up packet naming the role and database to connect to.
    Startup {
        /// Role name, sent as the `user` start-up parameter.
        user: String,
        /// Database name, sent as the `database` start-up parameter.
        database: String,
    },
    /// Password response (tag `'p'`), written as a NUL-terminated string.
    PasswordMessage(String),
    /// Simple query (tag `'Q'`) carrying SQL text, written as a NUL-terminated string.
    Query(String),
    /// Extended-protocol Parse (tag `'P'`): prepare `query` under the name `name`.
    Parse {
        /// Name of the prepared statement to create.
        name: String,
        /// SQL text to prepare.
        query: String,
        /// Parameter type OIDs, one big-endian `u32` per declared parameter.
        param_types: Vec<u32>,
    },
    /// Extended-protocol Bind (tag `'B'`): create `portal` from `statement`.
    Bind {
        /// Name of the portal to create.
        portal: String,
        /// Name of the prepared statement to bind.
        statement: String,
        /// Parameter values as raw bytes; `None` is written as NULL (length -1).
        params: Vec<Option<Vec<u8>>>,
    },
    /// Extended-protocol Execute (tag `'E'`): run `portal`.
    Execute {
        /// Name of the portal to run.
        portal: String,
        /// Row limit, written to the wire unchanged as a big-endian `i32`.
        max_rows: i32,
    },
    /// Sync (tag `'S'`); carries no body.
    Sync,
    /// Terminate (tag `'X'`); carries no body.
    Terminate,
    /// First SASL message (tag `'p'`): chosen mechanism plus length-prefixed data.
    SASLInitialResponse {
        /// Name of the selected SASL mechanism.
        mechanism: String,
        /// Mechanism-specific initial response bytes.
        data: Vec<u8>,
    },
    /// Subsequent SASL message (tag `'p'`); the bytes form the entire body.
    SASLResponse(Vec<u8>),
    /// GSSAPI/SSPI token (tag `'p'`); the bytes form the entire body.
    GSSResponse(Vec<u8>),
    /// CopyFail (tag `'f'`) carrying a NUL-terminated message.
    CopyFail(String),
    /// Close (tag `'C'`) for a portal or a prepared statement.
    Close {
        /// `true` closes a portal (`'P'`), `false` a prepared statement (`'S'`).
        is_portal: bool,
        /// Name of the portal or statement to close.
        name: String,
    },
}

71#[derive(Debug, Clone)]
73pub enum BackendMessage {
74 AuthenticationOk,
77 AuthenticationCleartextPassword,
79 AuthenticationMD5Password([u8; 4]),
81 AuthenticationKerberosV5,
83 AuthenticationGSS,
85 AuthenticationGSSContinue(Vec<u8>),
87 AuthenticationSSPI,
89 AuthenticationSASL(Vec<String>),
91 AuthenticationSASLContinue(Vec<u8>),
93 AuthenticationSASLFinal(Vec<u8>),
95 ParameterStatus {
97 name: String,
99 value: String,
101 },
102 BackendKeyData {
104 process_id: i32,
106 secret_key: i32,
108 },
109 ReadyForQuery(TransactionStatus),
111 RowDescription(Vec<FieldDescription>),
113 DataRow(Vec<Option<Vec<u8>>>),
115 CommandComplete(String),
117 ErrorResponse(ErrorFields),
119 ParseComplete,
121 BindComplete,
123 NoData,
125 PortalSuspended,
127 CopyInResponse {
129 format: u8,
131 column_formats: Vec<u8>,
133 },
134 CopyOutResponse {
136 format: u8,
138 column_formats: Vec<u8>,
140 },
141 CopyData(Vec<u8>),
143 CopyDone,
145 NotificationResponse {
147 process_id: i32,
149 channel: String,
151 payload: String,
153 },
154 EmptyQueryResponse,
156 NoticeResponse(ErrorFields),
158 ParameterDescription(Vec<u32>),
161 CloseComplete,
163}
164
/// Transaction state reported in a ReadyForQuery (`'Z'`) message.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare with `==` instead of
/// `matches!` and use the status as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionStatus {
    /// Status byte `'I'`: idle, not inside a transaction block.
    Idle,
    /// Status byte `'T'`: inside a transaction block.
    InBlock,
    /// Status byte `'E'`: inside a failed transaction block.
    Failed,
}

/// Description of one result column, decoded from a RowDescription (`'T'`) message.
///
/// `PartialEq`/`Eq` allow comparing whole column layouts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the table the column belongs to, as reported by the server.
    pub table_oid: u32,
    /// Attribute number of the column within that table, as reported.
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size, as reported by the server.
    pub type_size: i16,
    /// Type modifier, as reported by the server.
    pub type_modifier: i32,
    /// Format code for the column's values.
    pub format: i16,
}

/// Fields of an ErrorResponse (`'E'`) or NoticeResponse (`'N'`) message.
///
/// Only the field codes listed below are kept; the decoder ignores any others.
/// `PartialEq`/`Eq` allow comparing reports directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// Field `'S'`: severity (e.g. `"ERROR"`).
    pub severity: String,
    /// Field `'C'`: SQLSTATE code (e.g. `"42P01"`).
    pub code: String,
    /// Field `'M'`: primary message.
    pub message: String,
    /// Field `'D'`: optional detail text.
    pub detail: Option<String>,
    /// Field `'H'`: optional hint text.
    pub hint: Option<String>,
}

210impl FrontendMessage {
211 pub fn encode(&self) -> Vec<u8> {
213 match self {
214 FrontendMessage::Startup { user, database } => {
215 let mut buf = Vec::new();
216 buf.extend_from_slice(&196608i32.to_be_bytes());
218 buf.extend_from_slice(b"user\0");
220 buf.extend_from_slice(user.as_bytes());
221 buf.push(0);
222 buf.extend_from_slice(b"database\0");
223 buf.extend_from_slice(database.as_bytes());
224 buf.push(0);
225 buf.push(0); let len = (buf.len() + 4) as i32;
229 let mut result = len.to_be_bytes().to_vec();
230 result.extend(buf);
231 result
232 }
233 FrontendMessage::Query(sql) => {
234 let mut buf = Vec::new();
235 buf.push(b'Q');
236 let content = format!("{}\0", sql);
237 let len = (content.len() + 4) as i32;
238 buf.extend_from_slice(&len.to_be_bytes());
239 buf.extend_from_slice(content.as_bytes());
240 buf
241 }
242 FrontendMessage::Terminate => {
243 vec![b'X', 0, 0, 0, 4]
244 }
245 FrontendMessage::SASLInitialResponse { mechanism, data } => {
246 let mut buf = Vec::new();
247 buf.push(b'p'); let mut content = Vec::new();
250 content.extend_from_slice(mechanism.as_bytes());
251 content.push(0); content.extend_from_slice(&(data.len() as i32).to_be_bytes());
253 content.extend_from_slice(data);
254
255 let len = (content.len() + 4) as i32;
256 buf.extend_from_slice(&len.to_be_bytes());
257 buf.extend_from_slice(&content);
258 buf
259 }
260 FrontendMessage::SASLResponse(data) => {
261 let mut buf = Vec::new();
262 buf.push(b'p');
263
264 let len = (data.len() + 4) as i32;
265 buf.extend_from_slice(&len.to_be_bytes());
266 buf.extend_from_slice(data);
267 buf
268 }
269 FrontendMessage::GSSResponse(data) => {
270 let mut buf = Vec::new();
271 buf.push(b'p');
272
273 let len = (data.len() + 4) as i32;
274 buf.extend_from_slice(&len.to_be_bytes());
275 buf.extend_from_slice(data);
276 buf
277 }
278 FrontendMessage::PasswordMessage(password) => {
279 let mut buf = Vec::new();
280 buf.push(b'p');
281 let content = format!("{}\0", password);
282 let len = (content.len() + 4) as i32;
283 buf.extend_from_slice(&len.to_be_bytes());
284 buf.extend_from_slice(content.as_bytes());
285 buf
286 }
287 FrontendMessage::Parse {
288 name,
289 query,
290 param_types,
291 } => {
292 let mut buf = Vec::new();
293 buf.push(b'P');
294
295 let mut content = Vec::new();
296 content.extend_from_slice(name.as_bytes());
297 content.push(0);
298 content.extend_from_slice(query.as_bytes());
299 content.push(0);
300 content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
301 for oid in param_types {
302 content.extend_from_slice(&oid.to_be_bytes());
303 }
304
305 let len = (content.len() + 4) as i32;
306 buf.extend_from_slice(&len.to_be_bytes());
307 buf.extend_from_slice(&content);
308 buf
309 }
310 FrontendMessage::Bind {
311 portal,
312 statement,
313 params,
314 } => {
315 let mut buf = Vec::new();
316 buf.push(b'B');
317
318 let mut content = Vec::new();
319 content.extend_from_slice(portal.as_bytes());
320 content.push(0);
321 content.extend_from_slice(statement.as_bytes());
322 content.push(0);
323 content.extend_from_slice(&0i16.to_be_bytes());
325 content.extend_from_slice(&(params.len() as i16).to_be_bytes());
327 for param in params {
328 match param {
329 Some(data) => {
330 content.extend_from_slice(&(data.len() as i32).to_be_bytes());
331 content.extend_from_slice(data);
332 }
333 None => content.extend_from_slice(&(-1i32).to_be_bytes()),
334 }
335 }
336 content.extend_from_slice(&0i16.to_be_bytes());
338
339 let len = (content.len() + 4) as i32;
340 buf.extend_from_slice(&len.to_be_bytes());
341 buf.extend_from_slice(&content);
342 buf
343 }
344 FrontendMessage::Execute { portal, max_rows } => {
345 let mut buf = Vec::new();
346 buf.push(b'E');
347
348 let mut content = Vec::new();
349 content.extend_from_slice(portal.as_bytes());
350 content.push(0);
351 content.extend_from_slice(&max_rows.to_be_bytes());
352
353 let len = (content.len() + 4) as i32;
354 buf.extend_from_slice(&len.to_be_bytes());
355 buf.extend_from_slice(&content);
356 buf
357 }
358 FrontendMessage::Sync => {
359 vec![b'S', 0, 0, 0, 4]
360 }
361 FrontendMessage::CopyFail(msg) => {
362 let mut buf = Vec::new();
363 buf.push(b'f');
364 let content = format!("{}\0", msg);
365 let len = (content.len() + 4) as i32;
366 buf.extend_from_slice(&len.to_be_bytes());
367 buf.extend_from_slice(content.as_bytes());
368 buf
369 }
370 FrontendMessage::Close { is_portal, name } => {
371 let mut buf = Vec::new();
372 buf.push(b'C');
373 let type_byte = if *is_portal { b'P' } else { b'S' };
374 let mut content = vec![type_byte];
375 content.extend_from_slice(name.as_bytes());
376 content.push(0);
377 let len = (content.len() + 4) as i32;
378 buf.extend_from_slice(&len.to_be_bytes());
379 buf.extend_from_slice(&content);
380 buf
381 }
382 }
383 }
384}
385
386impl BackendMessage {
387 pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
389 if buf.len() < 5 {
390 return Err("Buffer too short".to_string());
391 }
392
393 let msg_type = buf[0];
394 let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
395
396 if len < 4 {
399 return Err(format!("Invalid message length: {} (minimum is 4)", len));
400 }
401
402 if buf.len() < len + 1 {
403 return Err("Incomplete message".to_string());
404 }
405
406 let payload = &buf[5..len + 1];
407
408 let message = match msg_type {
409 b'R' => Self::decode_auth(payload)?,
410 b'S' => Self::decode_parameter_status(payload)?,
411 b'K' => Self::decode_backend_key(payload)?,
412 b'Z' => Self::decode_ready_for_query(payload)?,
413 b'T' => Self::decode_row_description(payload)?,
414 b'D' => Self::decode_data_row(payload)?,
415 b'C' => Self::decode_command_complete(payload)?,
416 b'E' => Self::decode_error_response(payload)?,
417 b'1' => BackendMessage::ParseComplete,
418 b'2' => BackendMessage::BindComplete,
419 b'3' => BackendMessage::CloseComplete,
420 b'n' => BackendMessage::NoData,
421 b's' => BackendMessage::PortalSuspended,
422 b't' => Self::decode_parameter_description(payload)?,
423 b'G' => Self::decode_copy_in_response(payload)?,
424 b'H' => Self::decode_copy_out_response(payload)?,
425 b'd' => BackendMessage::CopyData(payload.to_vec()),
426 b'c' => BackendMessage::CopyDone,
427 b'A' => Self::decode_notification_response(payload)?,
428 b'I' => BackendMessage::EmptyQueryResponse,
429 b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
430 _ => return Err(format!("Unknown message type: {}", msg_type as char)),
431 };
432
433 Ok((message, len + 1))
434 }
435
436 fn decode_auth(payload: &[u8]) -> Result<Self, String> {
437 if payload.len() < 4 {
438 return Err("Auth payload too short".to_string());
439 }
440 let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
441 match auth_type {
442 0 => Ok(BackendMessage::AuthenticationOk),
443 2 => Ok(BackendMessage::AuthenticationKerberosV5),
444 3 => Ok(BackendMessage::AuthenticationCleartextPassword),
445 5 => {
446 if payload.len() < 8 {
447 return Err("MD5 auth payload too short (need salt)".to_string());
448 }
449 let salt: [u8; 4] = payload[4..8]
451 .try_into()
452 .expect("salt slice is exactly 4 bytes");
453 Ok(BackendMessage::AuthenticationMD5Password(salt))
454 }
455 7 => Ok(BackendMessage::AuthenticationGSS),
456 8 => Ok(BackendMessage::AuthenticationGSSContinue(
457 payload[4..].to_vec(),
458 )),
459 9 => Ok(BackendMessage::AuthenticationSSPI),
460 10 => {
461 let mut mechanisms = Vec::new();
463 let mut pos = 4;
464 while pos < payload.len() && payload[pos] != 0 {
465 let end = payload[pos..]
466 .iter()
467 .position(|&b| b == 0)
468 .map(|p| pos + p)
469 .unwrap_or(payload.len());
470 mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
471 pos = end + 1;
472 }
473 Ok(BackendMessage::AuthenticationSASL(mechanisms))
474 }
475 11 => {
476 Ok(BackendMessage::AuthenticationSASLContinue(
478 payload[4..].to_vec(),
479 ))
480 }
481 12 => {
482 Ok(BackendMessage::AuthenticationSASLFinal(
484 payload[4..].to_vec(),
485 ))
486 }
487 _ => Err(format!("Unknown auth type: {}", auth_type)),
488 }
489 }
490
491 fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
492 let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
493 let empty: &[u8] = b"";
494 Ok(BackendMessage::ParameterStatus {
495 name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
496 value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
497 })
498 }
499
500 fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
501 if payload.len() < 8 {
502 return Err("BackendKeyData payload too short".to_string());
503 }
504 Ok(BackendMessage::BackendKeyData {
505 process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
506 secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
507 })
508 }
509
510 fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
511 if payload.is_empty() {
512 return Err("ReadyForQuery payload empty".to_string());
513 }
514 let status = match payload[0] {
515 b'I' => TransactionStatus::Idle,
516 b'T' => TransactionStatus::InBlock,
517 b'E' => TransactionStatus::Failed,
518 _ => return Err("Unknown transaction status".to_string()),
519 };
520 Ok(BackendMessage::ReadyForQuery(status))
521 }
522
523 fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
524 if payload.len() < 2 {
525 return Err("RowDescription payload too short".to_string());
526 }
527
528 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
529 if raw_count < 0 {
530 return Err(format!("RowDescription invalid field count: {}", raw_count));
531 }
532 let field_count = raw_count as usize;
533 let mut fields = Vec::with_capacity(field_count);
534 let mut pos = 2;
535
536 for _ in 0..field_count {
537 let name_end = payload[pos..]
539 .iter()
540 .position(|&b| b == 0)
541 .ok_or("Missing null terminator in field name")?;
542 let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
543 pos += name_end + 1; if pos + 18 > payload.len() {
547 return Err("RowDescription field truncated".to_string());
548 }
549
550 let table_oid = u32::from_be_bytes([
551 payload[pos],
552 payload[pos + 1],
553 payload[pos + 2],
554 payload[pos + 3],
555 ]);
556 pos += 4;
557
558 let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
559 pos += 2;
560
561 let type_oid = u32::from_be_bytes([
562 payload[pos],
563 payload[pos + 1],
564 payload[pos + 2],
565 payload[pos + 3],
566 ]);
567 pos += 4;
568
569 let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
570 pos += 2;
571
572 let type_modifier = i32::from_be_bytes([
573 payload[pos],
574 payload[pos + 1],
575 payload[pos + 2],
576 payload[pos + 3],
577 ]);
578 pos += 4;
579
580 let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
581 pos += 2;
582
583 fields.push(FieldDescription {
584 name,
585 table_oid,
586 column_attr,
587 type_oid,
588 type_size,
589 type_modifier,
590 format,
591 });
592 }
593
594 Ok(BackendMessage::RowDescription(fields))
595 }
596
597 fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
598 if payload.len() < 2 {
599 return Err("DataRow payload too short".to_string());
600 }
601
602 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
603 if raw_count < 0 {
604 return Err(format!("DataRow invalid column count: {}", raw_count));
605 }
606 let column_count = raw_count as usize;
607 if column_count > (payload.len() - 2) / 4 + 1 {
609 return Err(format!(
610 "DataRow claims {} columns but payload is only {} bytes",
611 column_count,
612 payload.len()
613 ));
614 }
615 let mut columns = Vec::with_capacity(column_count);
616 let mut pos = 2;
617
618 for _ in 0..column_count {
619 if pos + 4 > payload.len() {
620 return Err("DataRow truncated".to_string());
621 }
622
623 let len = i32::from_be_bytes([
624 payload[pos],
625 payload[pos + 1],
626 payload[pos + 2],
627 payload[pos + 3],
628 ]);
629 pos += 4;
630
631 if len == -1 {
632 columns.push(None);
634 } else {
635 let len = len as usize;
636 if pos + len > payload.len() {
637 return Err("DataRow column data truncated".to_string());
638 }
639 let data = payload[pos..pos + len].to_vec();
640 pos += len;
641 columns.push(Some(data));
642 }
643 }
644
645 Ok(BackendMessage::DataRow(columns))
646 }
647
648 fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
649 let tag = String::from_utf8_lossy(payload)
650 .trim_end_matches('\0')
651 .to_string();
652 Ok(BackendMessage::CommandComplete(tag))
653 }
654
655 fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
656 Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
657 payload,
658 )?))
659 }
660
661 fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
662 let mut fields = ErrorFields::default();
663 let mut i = 0;
664 while i < payload.len() && payload[i] != 0 {
665 let field_type = payload[i];
666 i += 1;
667 let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
668 let value = String::from_utf8_lossy(&payload[i..end]).to_string();
669 i = end + 1;
670
671 match field_type {
672 b'S' => fields.severity = value,
673 b'C' => fields.code = value,
674 b'M' => fields.message = value,
675 b'D' => fields.detail = Some(value),
676 b'H' => fields.hint = Some(value),
677 _ => {}
678 }
679 }
680 Ok(fields)
681 }
682
683 fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
684 if payload.len() < 2 {
685 return Ok(BackendMessage::ParameterDescription(vec![]));
686 }
687 let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
688 if raw_count < 0 {
689 return Err(format!("ParameterDescription invalid count: {}", raw_count));
690 }
691 let count = raw_count as usize;
692 let mut oids = Vec::with_capacity(count);
693 let mut pos = 2;
694 for _ in 0..count {
695 if pos + 4 <= payload.len() {
696 oids.push(u32::from_be_bytes([
697 payload[pos],
698 payload[pos + 1],
699 payload[pos + 2],
700 payload[pos + 3],
701 ]));
702 pos += 4;
703 }
704 }
705 Ok(BackendMessage::ParameterDescription(oids))
706 }
707
708 fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
709 if payload.is_empty() {
710 return Err("Empty CopyInResponse payload".to_string());
711 }
712 let format = payload[0];
713 let num_columns = if payload.len() >= 3 {
714 let raw = i16::from_be_bytes([payload[1], payload[2]]);
715 if raw < 0 { 0usize } else { raw as usize }
716 } else {
717 0
718 };
719 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
720 payload[3..].iter().take(num_columns).copied().collect()
721 } else {
722 vec![]
723 };
724 Ok(BackendMessage::CopyInResponse {
725 format,
726 column_formats,
727 })
728 }
729
730 fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
731 if payload.is_empty() {
732 return Err("Empty CopyOutResponse payload".to_string());
733 }
734 let format = payload[0];
735 let num_columns = if payload.len() >= 3 {
736 let raw = i16::from_be_bytes([payload[1], payload[2]]);
737 if raw < 0 { 0usize } else { raw as usize }
738 } else {
739 0
740 };
741 let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
742 payload[3..].iter().take(num_columns).copied().collect()
743 } else {
744 vec![]
745 };
746 Ok(BackendMessage::CopyOutResponse {
747 format,
748 column_formats,
749 })
750 }
751
752 fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
753 if payload.len() < 6 {
754 return Err("NotificationResponse too short".to_string());
756 }
757 let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
758
759 let mut i = 4;
761 let remaining = payload.get(i..).unwrap_or(&[]);
762 let channel_end = remaining
763 .iter()
764 .position(|&b| b == 0)
765 .ok_or("NotificationResponse: missing channel null terminator")?;
766 let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
767 i += channel_end + 1;
768
769 let remaining = payload.get(i..).unwrap_or(&[]);
771 let payload_end = remaining
772 .iter()
773 .position(|&b| b == 0)
774 .unwrap_or(remaining.len());
775 let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
776
777 Ok(BackendMessage::NotificationResponse {
778 process_id,
779 channel,
780 payload: notification_payload,
781 })
782 }
783}
784
785#[cfg(test)]
786mod tests {
787 use super::*;
788
789 fn wire_msg(msg_type: u8, payload: &[u8]) -> Vec<u8> {
791 let len = (payload.len() + 4) as u32;
792 let mut buf = vec![msg_type];
793 buf.extend_from_slice(&len.to_be_bytes());
794 buf.extend_from_slice(payload);
795 buf
796 }
797
798 #[test]
801 fn decode_empty_buffer_returns_error() {
802 assert!(BackendMessage::decode(&[]).is_err());
803 }
804
805 #[test]
806 fn decode_too_short_buffer_returns_error() {
807 for len in 1..5 {
809 let buf = vec![b'Z'; len];
810 let result = BackendMessage::decode(&buf);
811 assert!(result.is_err(), "Expected error for {}-byte buffer", len);
812 }
813 }
814
815 #[test]
816 fn decode_incomplete_message_returns_error() {
817 let mut buf = vec![b'Z'];
819 buf.extend_from_slice(&100u32.to_be_bytes());
820 buf.extend_from_slice(&[0u8; 5]); assert!(
822 BackendMessage::decode(&buf)
823 .unwrap_err()
824 .contains("Incomplete")
825 );
826 }
827
828 #[test]
829 fn decode_unknown_message_type_returns_error() {
830 let buf = wire_msg(b'@', &[0]);
831 let result = BackendMessage::decode(&buf);
832 assert!(result.unwrap_err().contains("Unknown message type"));
833 }
834
835 #[test]
838 fn decode_auth_ok() {
839 let payload = 0i32.to_be_bytes();
840 let buf = wire_msg(b'R', &payload);
841 let (msg, consumed) = BackendMessage::decode(&buf).unwrap();
842 assert!(matches!(msg, BackendMessage::AuthenticationOk));
843 assert_eq!(consumed, buf.len());
844 }
845
846 #[test]
847 fn decode_auth_payload_too_short() {
848 let buf = wire_msg(b'R', &[0, 0]);
850 assert!(
851 BackendMessage::decode(&buf)
852 .unwrap_err()
853 .contains("too short")
854 );
855 }
856
857 #[test]
858 fn decode_auth_cleartext_password() {
859 let payload = 3i32.to_be_bytes();
860 let buf = wire_msg(b'R', &payload);
861 let (msg, _) = BackendMessage::decode(&buf).unwrap();
862 assert!(matches!(
863 msg,
864 BackendMessage::AuthenticationCleartextPassword
865 ));
866 }
867
868 #[test]
869 fn decode_auth_kerberos_v5() {
870 let payload = 2i32.to_be_bytes();
871 let buf = wire_msg(b'R', &payload);
872 let (msg, _) = BackendMessage::decode(&buf).unwrap();
873 assert!(matches!(msg, BackendMessage::AuthenticationKerberosV5));
874 }
875
876 #[test]
877 fn decode_auth_gss() {
878 let payload = 7i32.to_be_bytes();
879 let buf = wire_msg(b'R', &payload);
880 let (msg, _) = BackendMessage::decode(&buf).unwrap();
881 assert!(matches!(msg, BackendMessage::AuthenticationGSS));
882 }
883
884 #[test]
885 fn decode_auth_sspi() {
886 let payload = 9i32.to_be_bytes();
887 let buf = wire_msg(b'R', &payload);
888 let (msg, _) = BackendMessage::decode(&buf).unwrap();
889 assert!(matches!(msg, BackendMessage::AuthenticationSSPI));
890 }
891
892 #[test]
893 fn decode_auth_gss_continue() {
894 let mut payload = 8i32.to_be_bytes().to_vec();
895 payload.extend_from_slice(&[0xde, 0xad, 0xbe, 0xef]);
896 let buf = wire_msg(b'R', &payload);
897 let (msg, _) = BackendMessage::decode(&buf).unwrap();
898 match msg {
899 BackendMessage::AuthenticationGSSContinue(token) => {
900 assert_eq!(token, vec![0xde, 0xad, 0xbe, 0xef]);
901 }
902 _ => panic!("Expected AuthenticationGSSContinue"),
903 }
904 }
905
906 #[test]
907 fn decode_auth_md5_missing_salt() {
908 let mut payload = 5i32.to_be_bytes().to_vec();
910 payload.extend_from_slice(&[0, 0, 0]); let buf = wire_msg(b'R', &payload);
912 assert!(BackendMessage::decode(&buf).unwrap_err().contains("MD5"));
913 }
914
915 #[test]
916 fn decode_auth_md5_valid_salt() {
917 let mut payload = 5i32.to_be_bytes().to_vec();
918 payload.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
919 let buf = wire_msg(b'R', &payload);
920 let (msg, _) = BackendMessage::decode(&buf).unwrap();
921 match msg {
922 BackendMessage::AuthenticationMD5Password(salt) => {
923 assert_eq!(salt, [0xDE, 0xAD, 0xBE, 0xEF]);
924 }
925 _ => panic!("Expected MD5 auth"),
926 }
927 }
928
929 #[test]
930 fn decode_auth_unknown_type_returns_error() {
931 let payload = 99i32.to_be_bytes();
932 let buf = wire_msg(b'R', &payload);
933 assert!(
934 BackendMessage::decode(&buf)
935 .unwrap_err()
936 .contains("Unknown auth type")
937 );
938 }
939
940 #[test]
941 fn decode_auth_sasl_mechanisms() {
942 let mut payload = 10i32.to_be_bytes().to_vec();
943 payload.extend_from_slice(b"SCRAM-SHA-256\0\0"); let buf = wire_msg(b'R', &payload);
945 let (msg, _) = BackendMessage::decode(&buf).unwrap();
946 match msg {
947 BackendMessage::AuthenticationSASL(mechs) => {
948 assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
949 }
950 _ => panic!("Expected SASL auth"),
951 }
952 }
953
954 #[test]
957 fn decode_ready_for_query_idle() {
958 let buf = wire_msg(b'Z', b"I");
959 let (msg, _) = BackendMessage::decode(&buf).unwrap();
960 assert!(matches!(
961 msg,
962 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
963 ));
964 }
965
966 #[test]
967 fn decode_ready_for_query_in_transaction() {
968 let buf = wire_msg(b'Z', b"T");
969 let (msg, _) = BackendMessage::decode(&buf).unwrap();
970 assert!(matches!(
971 msg,
972 BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
973 ));
974 }
975
976 #[test]
977 fn decode_ready_for_query_failed() {
978 let buf = wire_msg(b'Z', b"E");
979 let (msg, _) = BackendMessage::decode(&buf).unwrap();
980 assert!(matches!(
981 msg,
982 BackendMessage::ReadyForQuery(TransactionStatus::Failed)
983 ));
984 }
985
986 #[test]
987 fn decode_ready_for_query_empty_payload() {
988 let buf = wire_msg(b'Z', &[]);
989 assert!(BackendMessage::decode(&buf).unwrap_err().contains("empty"));
990 }
991
992 #[test]
993 fn decode_ready_for_query_unknown_status() {
994 let buf = wire_msg(b'Z', b"X");
995 assert!(
996 BackendMessage::decode(&buf)
997 .unwrap_err()
998 .contains("Unknown transaction")
999 );
1000 }
1001
1002 #[test]
1005 fn decode_data_row_empty_columns() {
1006 let payload = 0i16.to_be_bytes();
1007 let buf = wire_msg(b'D', &payload);
1008 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1009 match msg {
1010 BackendMessage::DataRow(cols) => assert!(cols.is_empty()),
1011 _ => panic!("Expected DataRow"),
1012 }
1013 }
1014
1015 #[test]
1016 fn decode_data_row_with_null() {
1017 let mut payload = 1i16.to_be_bytes().to_vec();
1018 payload.extend_from_slice(&(-1i32).to_be_bytes()); let buf = wire_msg(b'D', &payload);
1020 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1021 match msg {
1022 BackendMessage::DataRow(cols) => {
1023 assert_eq!(cols.len(), 1);
1024 assert!(cols[0].is_none());
1025 }
1026 _ => panic!("Expected DataRow"),
1027 }
1028 }
1029
1030 #[test]
1031 fn decode_data_row_with_value() {
1032 let mut payload = 1i16.to_be_bytes().to_vec();
1033 let data = b"hello";
1034 payload.extend_from_slice(&(data.len() as i32).to_be_bytes());
1035 payload.extend_from_slice(data);
1036 let buf = wire_msg(b'D', &payload);
1037 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1038 match msg {
1039 BackendMessage::DataRow(cols) => {
1040 assert_eq!(cols.len(), 1);
1041 assert_eq!(cols[0].as_deref(), Some(b"hello".as_slice()));
1042 }
1043 _ => panic!("Expected DataRow"),
1044 }
1045 }
1046
1047 #[test]
1048 fn decode_data_row_negative_count_returns_error() {
1049 let payload = (-1i16).to_be_bytes();
1050 let buf = wire_msg(b'D', &payload);
1051 assert!(
1052 BackendMessage::decode(&buf)
1053 .unwrap_err()
1054 .contains("invalid column count")
1055 );
1056 }
1057
1058 #[test]
1059 fn decode_data_row_truncated_column_data() {
1060 let mut payload = 1i16.to_be_bytes().to_vec();
1061 payload.extend_from_slice(&100i32.to_be_bytes());
1063 let buf = wire_msg(b'D', &payload);
1064 assert!(
1065 BackendMessage::decode(&buf)
1066 .unwrap_err()
1067 .contains("truncated")
1068 );
1069 }
1070
1071 #[test]
1072 fn decode_data_row_payload_too_short() {
1073 let buf = wire_msg(b'D', &[0]); assert!(
1075 BackendMessage::decode(&buf)
1076 .unwrap_err()
1077 .contains("too short")
1078 );
1079 }
1080
1081 #[test]
1082 fn decode_data_row_claims_too_many_columns() {
1083 let payload = 1000i16.to_be_bytes();
1085 let buf = wire_msg(b'D', &payload);
1086 assert!(BackendMessage::decode(&buf).unwrap_err().contains("claims"));
1087 }
1088
1089 #[test]
1092 fn decode_row_description_zero_fields() {
1093 let payload = 0i16.to_be_bytes();
1094 let buf = wire_msg(b'T', &payload);
1095 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1096 match msg {
1097 BackendMessage::RowDescription(fields) => assert!(fields.is_empty()),
1098 _ => panic!("Expected RowDescription"),
1099 }
1100 }
1101
1102 #[test]
1103 fn decode_row_description_negative_count() {
1104 let payload = (-1i16).to_be_bytes();
1105 let buf = wire_msg(b'T', &payload);
1106 assert!(
1107 BackendMessage::decode(&buf)
1108 .unwrap_err()
1109 .contains("invalid field count")
1110 );
1111 }
1112
1113 #[test]
1114 fn decode_row_description_truncated_field() {
1115 let mut payload = 1i16.to_be_bytes().to_vec();
1116 payload.extend_from_slice(b"id\0"); let buf = wire_msg(b'T', &payload);
1119 assert!(
1120 BackendMessage::decode(&buf)
1121 .unwrap_err()
1122 .contains("truncated")
1123 );
1124 }
1125
1126 #[test]
1127 fn decode_row_description_single_field() {
1128 let mut payload = 1i16.to_be_bytes().to_vec();
1129 payload.extend_from_slice(b"id\0"); payload.extend_from_slice(&0u32.to_be_bytes()); payload.extend_from_slice(&0i16.to_be_bytes()); payload.extend_from_slice(&23u32.to_be_bytes()); payload.extend_from_slice(&4i16.to_be_bytes()); payload.extend_from_slice(&(-1i32).to_be_bytes()); payload.extend_from_slice(&0i16.to_be_bytes()); let buf = wire_msg(b'T', &payload);
1137 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1138 match msg {
1139 BackendMessage::RowDescription(fields) => {
1140 assert_eq!(fields.len(), 1);
1141 assert_eq!(fields[0].name, "id");
1142 assert_eq!(fields[0].type_oid, 23); }
1144 _ => panic!("Expected RowDescription"),
1145 }
1146 }
1147
1148 #[test]
1151 fn decode_backend_key_data() {
1152 let mut payload = 42i32.to_be_bytes().to_vec();
1153 payload.extend_from_slice(&99i32.to_be_bytes());
1154 let buf = wire_msg(b'K', &payload);
1155 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1156 match msg {
1157 BackendMessage::BackendKeyData {
1158 process_id,
1159 secret_key,
1160 } => {
1161 assert_eq!(process_id, 42);
1162 assert_eq!(secret_key, 99);
1163 }
1164 _ => panic!("Expected BackendKeyData"),
1165 }
1166 }
1167
1168 #[test]
1169 fn decode_backend_key_too_short() {
1170 let buf = wire_msg(b'K', &[0, 0, 0, 42]); assert!(
1172 BackendMessage::decode(&buf)
1173 .unwrap_err()
1174 .contains("too short")
1175 );
1176 }
1177
1178 #[test]
1181 fn decode_error_response_with_fields() {
1182 let mut payload = Vec::new();
1183 payload.push(b'S');
1184 payload.extend_from_slice(b"ERROR\0");
1185 payload.push(b'C');
1186 payload.extend_from_slice(b"42P01\0");
1187 payload.push(b'M');
1188 payload.extend_from_slice(b"relation does not exist\0");
1189 payload.push(0); let buf = wire_msg(b'E', &payload);
1191 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1192 match msg {
1193 BackendMessage::ErrorResponse(fields) => {
1194 assert_eq!(fields.severity, "ERROR");
1195 assert_eq!(fields.code, "42P01");
1196 assert_eq!(fields.message, "relation does not exist");
1197 }
1198 _ => panic!("Expected ErrorResponse"),
1199 }
1200 }
1201
1202 #[test]
1203 fn decode_error_response_empty() {
1204 let buf = wire_msg(b'E', &[0]); let (msg, _) = BackendMessage::decode(&buf).unwrap();
1206 match msg {
1207 BackendMessage::ErrorResponse(fields) => {
1208 assert!(fields.message.is_empty());
1209 }
1210 _ => panic!("Expected ErrorResponse"),
1211 }
1212 }
1213
1214 #[test]
1217 fn decode_command_complete() {
1218 let buf = wire_msg(b'C', b"INSERT 0 1\0");
1219 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1220 match msg {
1221 BackendMessage::CommandComplete(tag) => assert_eq!(tag, "INSERT 0 1"),
1222 _ => panic!("Expected CommandComplete"),
1223 }
1224 }
1225
1226 #[test]
1229 fn decode_parse_complete() {
1230 let buf = wire_msg(b'1', &[]);
1231 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1232 assert!(matches!(msg, BackendMessage::ParseComplete));
1233 }
1234
1235 #[test]
1236 fn decode_bind_complete() {
1237 let buf = wire_msg(b'2', &[]);
1238 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1239 assert!(matches!(msg, BackendMessage::BindComplete));
1240 }
1241
1242 #[test]
1243 fn decode_no_data() {
1244 let buf = wire_msg(b'n', &[]);
1245 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1246 assert!(matches!(msg, BackendMessage::NoData));
1247 }
1248
1249 #[test]
1250 fn decode_portal_suspended() {
1251 let buf = wire_msg(b's', &[]);
1252 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1253 assert!(matches!(msg, BackendMessage::PortalSuspended));
1254 }
1255
1256 #[test]
1257 fn decode_empty_query_response() {
1258 let buf = wire_msg(b'I', &[]);
1259 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1260 assert!(matches!(msg, BackendMessage::EmptyQueryResponse));
1261 }
1262
1263 #[test]
1266 fn decode_notification_response() {
1267 let mut payload = 1i32.to_be_bytes().to_vec();
1268 payload.extend_from_slice(b"my_channel\0");
1269 payload.extend_from_slice(b"hello world\0");
1270 let buf = wire_msg(b'A', &payload);
1271 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1272 match msg {
1273 BackendMessage::NotificationResponse {
1274 process_id,
1275 channel,
1276 payload,
1277 } => {
1278 assert_eq!(process_id, 1);
1279 assert_eq!(channel, "my_channel");
1280 assert_eq!(payload, "hello world");
1281 }
1282 _ => panic!("Expected NotificationResponse"),
1283 }
1284 }
1285
1286 #[test]
1287 fn decode_notification_too_short() {
1288 let buf = wire_msg(b'A', &[0, 0]); assert!(
1290 BackendMessage::decode(&buf)
1291 .unwrap_err()
1292 .contains("too short")
1293 );
1294 }
1295
1296 #[test]
1299 fn decode_copy_in_response_empty_payload() {
1300 let buf = wire_msg(b'G', &[]);
1301 assert!(BackendMessage::decode(&buf).unwrap_err().contains("Empty"));
1302 }
1303
1304 #[test]
1305 fn decode_copy_out_response_empty_payload() {
1306 let buf = wire_msg(b'H', &[]);
1307 assert!(BackendMessage::decode(&buf).unwrap_err().contains("Empty"));
1308 }
1309
1310 #[test]
1311 fn decode_copy_in_response_text_format() {
1312 let mut payload = vec![0u8]; payload.extend_from_slice(&1i16.to_be_bytes()); payload.push(0); let buf = wire_msg(b'G', &payload);
1316 let (msg, _) = BackendMessage::decode(&buf).unwrap();
1317 match msg {
1318 BackendMessage::CopyInResponse {
1319 format,
1320 column_formats,
1321 } => {
1322 assert_eq!(format, 0);
1323 assert_eq!(column_formats, vec![0]);
1324 }
1325 _ => panic!("Expected CopyInResponse"),
1326 }
1327 }
1328
1329 #[test]
1332 fn decode_consumed_length_is_correct() {
1333 let buf = wire_msg(b'Z', b"I");
1334 let (_, consumed) = BackendMessage::decode(&buf).unwrap();
1335 assert_eq!(consumed, buf.len());
1336 }
1337
1338 #[test]
1339 fn decode_with_trailing_data_only_consumes_one_message() {
1340 let mut buf = wire_msg(b'Z', b"I");
1341 buf.extend_from_slice(&wire_msg(b'Z', b"T")); let (msg, consumed) = BackendMessage::decode(&buf).unwrap();
1343 assert!(matches!(
1344 msg,
1345 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
1346 ));
1347 assert_eq!(consumed, 6); }
1350
1351 #[test]
1354 fn encode_sync() {
1355 let msg = FrontendMessage::Sync;
1356 let encoded = msg.encode();
1357 assert_eq!(encoded, vec![b'S', 0, 0, 0, 4]);
1358 }
1359
1360 #[test]
1361 fn encode_gss_response() {
1362 let msg = FrontendMessage::GSSResponse(vec![1, 2, 3, 4]);
1363 let encoded = msg.encode();
1364 assert_eq!(encoded[0], b'p');
1365 let len = i32::from_be_bytes([encoded[1], encoded[2], encoded[3], encoded[4]]);
1366 assert_eq!(len, 8);
1367 assert_eq!(&encoded[5..], &[1, 2, 3, 4]);
1368 }
1369
1370 #[test]
1371 fn encode_terminate() {
1372 let msg = FrontendMessage::Terminate;
1373 let encoded = msg.encode();
1374 assert_eq!(encoded, vec![b'X', 0, 0, 0, 4]);
1375 }
1376}