1use crate::{ProxyError, Result};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use std::collections::HashMap;
8
/// Tag-level classification of a PostgreSQL wire-protocol message.
///
/// Several tag bytes are shared between the frontend (client → server) and
/// backend (server → client) directions — e.g. `b'D'` is `Describe` from a
/// client but `DataRow` from a server — so the correct variant depends on who
/// sent the byte. [`MessageType::from_tag`] resolves shared tags to their
/// frontend meaning; [`MessageType::from_backend_tag`] to their backend meaning.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Untagged startup packet (no tag byte on the wire).
    Startup,
    /// Untagged SSL negotiation request (no tag byte on the wire).
    SSLRequest,
    /// `R` — authentication request (backend).
    AuthRequest,
    /// `p` — password / authentication response (frontend).
    Password,
    /// `Q` — simple query (frontend).
    Query,
    /// `P` — parse (frontend).
    Parse,
    /// `B` — bind (frontend).
    Bind,
    /// `D` — describe (frontend; shares its tag with `DataRow`).
    Describe,
    /// `E` — execute (frontend; shares its tag with `ErrorResponse`).
    Execute,
    /// `S` — sync (frontend; shares its tag with `ParameterStatus`).
    Sync,
    /// `H` — flush (frontend).
    Flush,
    /// `C` — close (frontend; shares its tag with `CommandComplete`).
    Close,
    /// `X` — terminate (frontend).
    Terminate,
    /// `d` — copy data (either direction).
    CopyData,
    /// `c` — copy done (either direction).
    CopyDone,
    /// `f` — copy fail (frontend).
    CopyFail,
    /// `F` — function call (frontend).
    FunctionCall,
    /// `K` — backend key data (backend).
    BackendKeyData,
    /// `S` — parameter status (backend).
    ParameterStatus,
    /// `Z` — ready for query (backend).
    ReadyForQuery,
    /// `T` — row description (backend).
    RowDescription,
    /// `D` — data row (backend).
    DataRow,
    /// `C` — command complete (backend).
    CommandComplete,
    /// `I` — empty query response (backend).
    EmptyQueryResponse,
    /// `E` — error response (backend).
    ErrorResponse,
    /// `N` — notice response (backend).
    NoticeResponse,
    /// `A` — notification response (backend).
    NotificationResponse,
    /// `1` — parse complete (backend).
    ParseComplete,
    /// `2` — bind complete (backend).
    BindComplete,
    /// `3` — close complete (backend).
    CloseComplete,
    /// `s` — portal suspended (backend).
    PortalSuspended,
    /// `n` — no data (backend).
    NoData,
    /// `t` — parameter description (backend).
    ParameterDescription,
    /// Any tag byte not modelled above, preserved verbatim so the frame can be
    /// re-encoded unchanged.
    Unknown(u8),
}

impl MessageType {
    /// Maps a tag byte to a message type, resolving shared tags (`D`, `E`,
    /// `S`, `C`) to their frontend meaning.
    ///
    /// Backend-only tags that do not collide with a frontend tag are also
    /// recognised. Anything else becomes [`MessageType::Unknown`]. For bytes
    /// read from a server, prefer [`MessageType::from_backend_tag`].
    pub fn from_tag(tag: u8) -> Self {
        match tag {
            b'Q' => MessageType::Query,
            b'P' => MessageType::Parse,
            b'B' => MessageType::Bind,
            b'D' => MessageType::Describe,
            b'E' => MessageType::Execute,
            b'S' => MessageType::Sync,
            b'H' => MessageType::Flush,
            b'C' => MessageType::Close,
            b'X' => MessageType::Terminate,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            b'f' => MessageType::CopyFail,
            b'F' => MessageType::FunctionCall,
            b'p' => MessageType::Password,
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'I' => MessageType::EmptyQueryResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Maps a tag byte received from the backend (server → client).
    ///
    /// Shared tags resolve to their backend meaning: `D` → `DataRow`,
    /// `E` → `ErrorResponse`, `S` → `ParameterStatus`, `C` → `CommandComplete`.
    /// Frontend-only tags are not valid in this direction and, like any other
    /// unmodelled byte (e.g. `H` CopyOutResponse), become
    /// [`MessageType::Unknown`].
    pub fn from_backend_tag(tag: u8) -> Self {
        match tag {
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            b'S' => MessageType::ParameterStatus,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'D' => MessageType::DataRow,
            b'C' => MessageType::CommandComplete,
            b'I' => MessageType::EmptyQueryResponse,
            b'E' => MessageType::ErrorResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Returns the tag byte written before the length word, or `None` for the
    /// untagged startup-phase packets.
    ///
    /// `Unknown(tag)` yields its original byte so that relaying an unmodelled
    /// frame reproduces it exactly instead of dropping the tag.
    pub fn to_tag(&self) -> Option<u8> {
        match self {
            MessageType::Startup | MessageType::SSLRequest => None,
            MessageType::Query => Some(b'Q'),
            MessageType::Parse => Some(b'P'),
            MessageType::Bind => Some(b'B'),
            MessageType::Describe => Some(b'D'),
            MessageType::Execute => Some(b'E'),
            MessageType::Sync => Some(b'S'),
            MessageType::Flush => Some(b'H'),
            MessageType::Close => Some(b'C'),
            MessageType::Terminate => Some(b'X'),
            MessageType::CopyData => Some(b'd'),
            MessageType::CopyDone => Some(b'c'),
            MessageType::CopyFail => Some(b'f'),
            MessageType::FunctionCall => Some(b'F'),
            MessageType::Password => Some(b'p'),
            MessageType::AuthRequest => Some(b'R'),
            MessageType::BackendKeyData => Some(b'K'),
            MessageType::ParameterStatus => Some(b'S'),
            MessageType::ReadyForQuery => Some(b'Z'),
            MessageType::RowDescription => Some(b'T'),
            MessageType::DataRow => Some(b'D'),
            MessageType::CommandComplete => Some(b'C'),
            MessageType::EmptyQueryResponse => Some(b'I'),
            MessageType::ErrorResponse => Some(b'E'),
            MessageType::NoticeResponse => Some(b'N'),
            MessageType::NotificationResponse => Some(b'A'),
            MessageType::ParseComplete => Some(b'1'),
            MessageType::BindComplete => Some(b'2'),
            MessageType::CloseComplete => Some(b'3'),
            MessageType::PortalSuspended => Some(b's'),
            MessageType::NoData => Some(b'n'),
            MessageType::ParameterDescription => Some(b't'),
            MessageType::Unknown(tag) => Some(*tag),
        }
    }
}
160
161#[derive(Debug, Clone)]
163pub struct Message {
164 pub msg_type: MessageType,
166 pub payload: BytesMut,
168}
169
170impl Message {
171 pub fn new(msg_type: MessageType, payload: BytesMut) -> Self {
173 Self { msg_type, payload }
174 }
175
176 pub fn empty(msg_type: MessageType) -> Self {
178 Self {
179 msg_type,
180 payload: BytesMut::new(),
181 }
182 }
183
184 pub fn encode(&self) -> BytesMut {
186 let mut buf = BytesMut::new();
187
188 if let Some(tag) = self.msg_type.to_tag() {
189 buf.put_u8(tag);
190 }
191
192 let len = self.payload.len() as u32 + 4;
194 buf.put_u32(len);
195 buf.extend_from_slice(&self.payload);
196
197 buf
198 }
199}
200
201pub struct ProtocolCodec {
203 max_message_size: usize,
205}
206
207impl Default for ProtocolCodec {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213impl ProtocolCodec {
214 pub fn new() -> Self {
216 Self {
217 max_message_size: 100 * 1024 * 1024, }
219 }
220
221 pub fn with_max_size(max_message_size: usize) -> Self {
223 Self { max_message_size }
224 }
225
226 pub fn decode_startup(&self, src: &mut BytesMut) -> Result<Option<StartupMessage>> {
228 if src.len() < 4 {
229 return Ok(None);
230 }
231
232 let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
233
234 if len > self.max_message_size {
235 return Err(ProxyError::Protocol(format!(
236 "Message too large: {} bytes",
237 len
238 )));
239 }
240
241 if src.len() < len {
242 return Ok(None);
243 }
244
245 src.advance(4);
246 let protocol_version = src.get_u32();
247
248 if protocol_version == 80877103 {
250 return Ok(Some(StartupMessage::SSLRequest));
251 }
252
253 if protocol_version == 80877102 {
255 let pid = src.get_u32();
256 let key = src.get_u32();
257 return Ok(Some(StartupMessage::CancelRequest { pid, key }));
258 }
259
260 let mut params = HashMap::new();
262 let remaining = len - 8; let mut param_bytes = src.split_to(remaining);
264
265 while param_bytes.has_remaining() {
266 let key = read_cstring(&mut param_bytes)?;
267 if key.is_empty() {
268 break;
269 }
270 let value = read_cstring(&mut param_bytes)?;
271 params.insert(key, value);
272 }
273
274 Ok(Some(StartupMessage::Startup {
275 protocol_version,
276 params,
277 }))
278 }
279
280 pub fn decode_message(&self, src: &mut BytesMut) -> Result<Option<Message>> {
282 if src.len() < 5 {
283 return Ok(None);
284 }
285
286 let tag = src[0];
287 let len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
288
289 if len > self.max_message_size {
290 return Err(ProxyError::Protocol(format!(
291 "Message too large: {} bytes",
292 len
293 )));
294 }
295
296 let total_len = 1 + len;
298 if src.len() < total_len {
299 return Ok(None);
300 }
301
302 src.advance(5); let payload = src.split_to(len - 4); let msg_type = MessageType::from_tag(tag);
306 Ok(Some(Message::new(msg_type, payload)))
307 }
308
309 pub fn encode_message(&self, msg: &Message) -> BytesMut {
311 msg.encode()
312 }
313}
314
/// A decoded untagged startup-phase packet.
#[derive(Debug, Clone)]
pub enum StartupMessage {
    /// Regular startup packet.
    Startup {
        /// Version word read from the packet, right after the length.
        protocol_version: u32,
        /// Key/value parameters that followed the version word.
        params: HashMap<String, String>,
    },
    /// Request to negotiate SSL before the regular startup packet.
    SSLRequest,
    /// Cancel packet; both fields are read from the wire in this order.
    CancelRequest {
        pid: u32,
        key: u32,
    },
}
328
329fn read_cstring(buf: &mut BytesMut) -> Result<String> {
336 let end = buf.iter().position(|&b| b == 0).ok_or_else(|| {
337 ProxyError::Protocol("unterminated cstring in protocol buffer".to_string())
338 })?;
339
340 let bytes = buf.split_to(end);
341 buf.advance(1); String::from_utf8(bytes.into())
344 .map_err(|e| ProxyError::Protocol(format!("Invalid UTF-8 in cstring: {}", e)))
345}
346
347fn write_cstring(buf: &mut BytesMut, s: &str) {
349 buf.extend_from_slice(s.as_bytes());
350 buf.put_u8(0);
351}
352
/// Returns the text before the first NUL byte in `payload`.
///
/// Returns `None` if `payload` contains no NUL, or if the bytes before it are
/// not valid UTF-8.
pub fn query_text(payload: &[u8]) -> Option<&str> {
    payload
        .iter()
        .position(|&byte| byte == b'\0')
        .and_then(|nul| std::str::from_utf8(&payload[..nul]).ok())
}
361
/// Reports whether `s` begins with `prefix`, ignoring ASCII case.
///
/// Non-ASCII bytes must match exactly. An empty `prefix` always matches.
pub fn starts_with_ci(s: &str, prefix: &str) -> bool {
    s.as_bytes()
        .get(..prefix.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(prefix.as_bytes()))
}
367
/// Reports whether `needle` occurs anywhere in `haystack`, ignoring ASCII case.
///
/// An empty `needle` always matches. The scan is a byte-window search costing
/// O(haystack × needle).
pub fn contains_ci(haystack: &str, needle: &str) -> bool {
    let pattern = needle.as_bytes();
    // The empty check must come first: `windows(0)` would panic.
    pattern.is_empty()
        || haystack
            .as_bytes()
            .windows(pattern.len())
            .any(|window| window.eq_ignore_ascii_case(pattern))
}
381
382#[derive(Debug, Clone)]
384pub struct QueryMessage {
385 pub query: String,
386}
387
388impl QueryMessage {
389 pub fn parse(mut payload: BytesMut) -> Result<Self> {
391 let query = read_cstring(&mut payload)?;
392 Ok(Self { query })
393 }
394
395 pub fn encode(&self) -> Message {
397 let mut payload = BytesMut::new();
398 write_cstring(&mut payload, &self.query);
399 Message::new(MessageType::Query, payload)
400 }
401}
402
403#[derive(Debug, Clone)]
405pub struct ParseMessage {
406 pub name: String,
407 pub query: String,
408 pub param_types: Vec<u32>,
409}
410
411impl ParseMessage {
412 pub fn parse(mut payload: BytesMut) -> Result<Self> {
414 let name = read_cstring(&mut payload)?;
415 let query = read_cstring(&mut payload)?;
416
417 let num_params = payload.get_u16() as usize;
418 let mut param_types = Vec::with_capacity(num_params);
419
420 for _ in 0..num_params {
421 param_types.push(payload.get_u32());
422 }
423
424 Ok(Self {
425 name,
426 query,
427 param_types,
428 })
429 }
430
431 pub fn encode(&self) -> Message {
433 let mut payload = BytesMut::new();
434 write_cstring(&mut payload, &self.name);
435 write_cstring(&mut payload, &self.query);
436 payload.put_u16(self.param_types.len() as u16);
437 for &t in &self.param_types {
438 payload.put_u32(t);
439 }
440 Message::new(MessageType::Parse, payload)
441 }
442}
443
444#[derive(Debug, Clone)]
450pub struct BindMessage {
451 pub portal: String,
452 pub statement: String,
453 pub param_formats: Vec<i16>,
454 pub param_values: Vec<Option<Bytes>>,
455 pub result_formats: Vec<i16>,
456}
457
458impl BindMessage {
459 pub fn parse(mut payload: BytesMut) -> Result<Self> {
461 let portal = read_cstring(&mut payload)?;
462 let statement = read_cstring(&mut payload)?;
463
464 let num_formats = payload.get_u16() as usize;
466 let mut param_formats = Vec::with_capacity(num_formats);
467 for _ in 0..num_formats {
468 param_formats.push(payload.get_i16());
469 }
470
471 let num_values = payload.get_u16() as usize;
475 let mut param_values = Vec::with_capacity(num_values);
476 for _ in 0..num_values {
477 let len = payload.get_i32();
478 if len == -1 {
479 param_values.push(None);
480 } else {
481 let value = payload.split_to(len as usize).freeze();
482 param_values.push(Some(value));
483 }
484 }
485
486 let num_result_formats = payload.get_u16() as usize;
488 let mut result_formats = Vec::with_capacity(num_result_formats);
489 for _ in 0..num_result_formats {
490 result_formats.push(payload.get_i16());
491 }
492
493 Ok(Self {
494 portal,
495 statement,
496 param_formats,
497 param_values,
498 result_formats,
499 })
500 }
501}
502
503#[derive(Debug, Clone)]
505pub struct ExecuteMessage {
506 pub portal: String,
507 pub max_rows: i32,
508}
509
510impl ExecuteMessage {
511 pub fn parse(mut payload: BytesMut) -> Result<Self> {
513 let portal = read_cstring(&mut payload)?;
514 let max_rows = payload.get_i32();
515 Ok(Self { portal, max_rows })
516 }
517
518 pub fn encode(&self) -> Message {
520 let mut payload = BytesMut::new();
521 write_cstring(&mut payload, &self.portal);
522 payload.put_i32(self.max_rows);
523 Message::new(MessageType::Execute, payload)
524 }
525}
526
527#[derive(Debug, Clone)]
529pub struct ErrorResponse {
530 pub fields: HashMap<char, String>,
531}
532
533impl ErrorResponse {
534 pub fn parse(mut payload: BytesMut) -> Result<Self> {
536 let mut fields = HashMap::new();
537
538 while payload.has_remaining() {
539 let code = payload.get_u8();
540 if code == 0 {
541 break;
542 }
543 let value = read_cstring(&mut payload)?;
544 fields.insert(code as char, value);
545 }
546
547 Ok(Self { fields })
548 }
549
550 pub fn severity(&self) -> Option<&str> {
552 self.fields.get(&'S').map(|s| s.as_str())
553 }
554
555 pub fn code(&self) -> Option<&str> {
557 self.fields.get(&'C').map(|s| s.as_str())
558 }
559
560 pub fn message(&self) -> Option<&str> {
562 self.fields.get(&'M').map(|s| s.as_str())
563 }
564
565 pub fn encode(&self) -> Message {
567 let mut payload = BytesMut::new();
568 for (&code, value) in &self.fields {
569 payload.put_u8(code as u8);
570 write_cstring(&mut payload, value);
571 }
572 payload.put_u8(0);
573 Message::new(MessageType::ErrorResponse, payload)
574 }
575}
576
/// Transaction state byte reported with the `Z` (ReadyForQuery) message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// `I`
    Idle,
    /// `T`
    InTransaction,
    /// `E`
    Failed,
}

impl TransactionStatus {
    /// Converts a status byte to a variant.
    ///
    /// `I` and any unrecognised byte both map to [`TransactionStatus::Idle`].
    pub fn from_byte(b: u8) -> Self {
        match b {
            b'T' => Self::InTransaction,
            b'E' => Self::Failed,
            _ => Self::Idle,
        }
    }

    /// Converts back to the status byte; the inverse of `from_byte` for the
    /// three known bytes.
    pub fn to_byte(&self) -> u8 {
        match *self {
            Self::Idle => b'I',
            Self::InTransaction => b'T',
            Self::Failed => b'E',
        }
    }
}
608
609#[derive(Debug, Clone)]
611pub struct CommandComplete {
612 pub tag: String,
613}
614
615impl CommandComplete {
616 pub fn parse(mut payload: BytesMut) -> Result<Self> {
618 let tag = read_cstring(&mut payload)?;
619 Ok(Self { tag })
620 }
621
622 pub fn encode(&self) -> Message {
624 let mut payload = BytesMut::new();
625 write_cstring(&mut payload, &self.tag);
626 Message::new(MessageType::CommandComplete, payload)
627 }
628
629 pub fn rows_affected(&self) -> Option<u64> {
631 let parts: Vec<&str> = self.tag.split_whitespace().collect();
632 if parts.len() >= 2 {
633 parts.last()?.parse().ok()
634 } else {
635 None
636 }
637 }
638}
639
640#[derive(Debug, Clone)]
642pub enum AuthRequest {
643 Ok,
645 CleartextPassword,
647 Md5Password { salt: [u8; 4] },
649 SASL { mechanisms: Vec<String> },
651 SASLContinue { data: Vec<u8> },
653 SASLFinal { data: Vec<u8> },
655 Unknown(i32),
657}
658
659impl AuthRequest {
660 pub fn parse(mut payload: BytesMut) -> Result<Self> {
662 let auth_type = payload.get_i32();
663
664 Ok(match auth_type {
665 0 => AuthRequest::Ok,
666 3 => AuthRequest::CleartextPassword,
667 5 => {
668 let mut salt = [0u8; 4];
669 payload.copy_to_slice(&mut salt);
670 AuthRequest::Md5Password { salt }
671 }
672 10 => {
673 let mut mechanisms = Vec::new();
674 loop {
675 let mech = read_cstring(&mut payload)?;
676 if mech.is_empty() {
677 break;
678 }
679 mechanisms.push(mech);
680 }
681 AuthRequest::SASL { mechanisms }
682 }
683 11 => {
684 let data = payload.to_vec();
685 AuthRequest::SASLContinue { data }
686 }
687 12 => {
688 let data = payload.to_vec();
689 AuthRequest::SASLFinal { data }
690 }
691 _ => AuthRequest::Unknown(auth_type),
692 })
693 }
694
695 pub fn encode(&self) -> Message {
697 let mut payload = BytesMut::new();
698
699 match self {
700 AuthRequest::Ok => {
701 payload.put_i32(0);
702 }
703 AuthRequest::CleartextPassword => {
704 payload.put_i32(3);
705 }
706 AuthRequest::Md5Password { salt } => {
707 payload.put_i32(5);
708 payload.extend_from_slice(salt);
709 }
710 AuthRequest::SASL { mechanisms } => {
711 payload.put_i32(10);
712 for mech in mechanisms {
713 write_cstring(&mut payload, mech);
714 }
715 payload.put_u8(0);
716 }
717 AuthRequest::SASLContinue { data } => {
718 payload.put_i32(11);
719 payload.extend_from_slice(data);
720 }
721 AuthRequest::SASLFinal { data } => {
722 payload.put_i32(12);
723 payload.extend_from_slice(data);
724 }
725 AuthRequest::Unknown(t) => {
726 payload.put_i32(*t);
727 }
728 }
729
730 Message::new(MessageType::AuthRequest, payload)
731 }
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_round_trip() {
        let types = vec![
            MessageType::Query,
            MessageType::Parse,
            MessageType::Bind,
            MessageType::Execute,
            MessageType::Sync,
        ];

        for msg_type in types {
            if let Some(tag) = msg_type.to_tag() {
                let decoded = MessageType::from_tag(tag);
                assert_eq!(decoded, msg_type);
            }
        }
    }

    #[test]
    fn test_auth_request_tag_mapping() {
        assert_eq!(MessageType::from_tag(b'R'), MessageType::AuthRequest);
        assert_eq!(MessageType::AuthRequest.to_tag(), Some(b'R'));
    }

    #[test]
    fn test_query_message() {
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let msg = query.encode();
        assert_eq!(msg.msg_type, MessageType::Query);

        let decoded = QueryMessage::parse(msg.payload).unwrap();
        assert_eq!(decoded.query, "SELECT 1");
    }

    #[test]
    fn test_error_response() {
        let mut fields = HashMap::new();
        fields.insert('S', "ERROR".to_string());
        fields.insert('C', "42P01".to_string());
        fields.insert('M', "relation does not exist".to_string());

        let err = ErrorResponse { fields };
        assert_eq!(err.severity(), Some("ERROR"));
        assert_eq!(err.code(), Some("42P01"));
        assert_eq!(err.message(), Some("relation does not exist"));
    }

    #[test]
    fn test_error_response_round_trip() {
        let mut fields = HashMap::new();
        fields.insert('S', "ERROR".to_string());
        fields.insert('M', "boom".to_string());

        let msg = ErrorResponse {
            fields: fields.clone(),
        }
        .encode();
        assert_eq!(msg.msg_type, MessageType::ErrorResponse);

        let parsed = ErrorResponse::parse(msg.payload).unwrap();
        assert_eq!(parsed.fields, fields);
    }

    #[test]
    fn test_command_complete() {
        let cmd = CommandComplete {
            tag: "INSERT 0 5".to_string(),
        };
        assert_eq!(cmd.rows_affected(), Some(5));

        let cmd2 = CommandComplete {
            tag: "SELECT 100".to_string(),
        };
        assert_eq!(cmd2.rows_affected(), Some(100));
    }

    #[test]
    fn test_transaction_status() {
        assert_eq!(TransactionStatus::from_byte(b'I'), TransactionStatus::Idle);
        assert_eq!(
            TransactionStatus::from_byte(b'T'),
            TransactionStatus::InTransaction
        );
        assert_eq!(
            TransactionStatus::from_byte(b'E'),
            TransactionStatus::Failed
        );

        assert_eq!(TransactionStatus::Idle.to_byte(), b'I');
        assert_eq!(TransactionStatus::InTransaction.to_byte(), b'T');
        assert_eq!(TransactionStatus::Failed.to_byte(), b'E');
    }

    #[test]
    fn test_protocol_codec() {
        let codec = ProtocolCodec::new();
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let msg = query.encode();
        let encoded = codec.encode_message(&msg);

        assert!(encoded.len() > 5);
        assert_eq!(encoded[0], b'Q');
    }

    /// Two consecutive frames decode in order and the buffer ends empty.
    #[test]
    fn test_decode_message_round_trip() {
        let codec = ProtocolCodec::new();
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let mut src = codec.encode_message(&query.encode());
        // Append a bodiless Sync frame: tag, then length 4.
        src.extend_from_slice(&[b'S', 0, 0, 0, 4]);

        let msg = codec
            .decode_message(&mut src)
            .unwrap()
            .expect("complete Query frame");
        assert_eq!(msg.msg_type, MessageType::Query);
        assert_eq!(QueryMessage::parse(msg.payload).unwrap().query, "SELECT 1");

        let sync = codec
            .decode_message(&mut src)
            .unwrap()
            .expect("complete Sync frame");
        assert_eq!(sync.msg_type, MessageType::Sync);
        assert!(sync.payload.is_empty());
        assert!(src.is_empty());
    }

    /// A partially buffered frame yields `None` and consumes nothing.
    #[test]
    fn test_decode_message_waits_for_full_frame() {
        let codec = ProtocolCodec::new();
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let encoded = codec.encode_message(&query.encode());
        let mut partial = BytesMut::from(&encoded[..encoded.len() - 1]);

        assert!(codec.decode_message(&mut partial).unwrap().is_none());
        assert_eq!(partial.len(), encoded.len() - 1);
    }

    #[test]
    fn test_decode_message_rejects_oversized() {
        let codec = ProtocolCodec::with_max_size(16);
        let mut src = BytesMut::new();
        src.put_u8(b'Q');
        src.put_u32(1024);

        assert!(matches!(
            codec.decode_message(&mut src),
            Err(ProxyError::Protocol(_))
        ));
    }

    #[test]
    fn test_decode_startup_ssl_request() {
        let codec = ProtocolCodec::new();
        let mut src = BytesMut::new();
        src.put_u32(8);
        src.put_u32(80877103);

        assert!(matches!(
            codec.decode_startup(&mut src).unwrap(),
            Some(StartupMessage::SSLRequest)
        ));
        assert!(src.is_empty());
    }

    #[test]
    fn test_decode_startup_params() {
        let codec = ProtocolCodec::new();
        let mut body = BytesMut::new();
        body.put_u32(196608);
        write_cstring(&mut body, "user");
        write_cstring(&mut body, "alice");
        write_cstring(&mut body, "database");
        write_cstring(&mut body, "app");
        body.put_u8(0);

        let mut src = BytesMut::new();
        src.put_u32(body.len() as u32 + 4);
        src.extend_from_slice(&body);

        match codec.decode_startup(&mut src).unwrap() {
            Some(StartupMessage::Startup {
                protocol_version,
                params,
            }) => {
                assert_eq!(protocol_version, 196608);
                assert_eq!(params.get("user").map(String::as_str), Some("alice"));
                assert_eq!(params.get("database").map(String::as_str), Some("app"));
            }
            other => panic!("expected Startup, got {other:?}"),
        }
        assert!(src.is_empty());
    }

    #[test]
    fn test_parse_message_round_trip() {
        let parse = ParseMessage {
            name: "s1".to_string(),
            query: "SELECT $1".to_string(),
            param_types: vec![23, 25],
        };
        let decoded = ParseMessage::parse(parse.encode().payload).unwrap();
        assert_eq!(decoded.name, "s1");
        assert_eq!(decoded.query, "SELECT $1");
        assert_eq!(decoded.param_types, vec![23, 25]);
    }

    #[test]
    fn test_execute_message_round_trip() {
        let exec = ExecuteMessage {
            portal: "p1".to_string(),
            max_rows: 50,
        };
        let decoded = ExecuteMessage::parse(exec.encode().payload).unwrap();
        assert_eq!(decoded.portal, "p1");
        assert_eq!(decoded.max_rows, 50);
    }

    #[test]
    fn test_auth_request_round_trip() {
        let md5 = AuthRequest::Md5Password { salt: [1, 2, 3, 4] };
        match AuthRequest::parse(md5.encode().payload).unwrap() {
            AuthRequest::Md5Password { salt } => assert_eq!(salt, [1, 2, 3, 4]),
            other => panic!("expected Md5Password, got {other:?}"),
        }

        let sasl = AuthRequest::SASL {
            mechanisms: vec!["SCRAM-SHA-256".to_string()],
        };
        match AuthRequest::parse(sasl.encode().payload).unwrap() {
            AuthRequest::SASL { mechanisms } => {
                assert_eq!(mechanisms, vec!["SCRAM-SHA-256".to_string()])
            }
            other => panic!("expected SASL, got {other:?}"),
        }
    }

    #[test]
    fn test_read_cstring_unterminated() {
        let mut buf = BytesMut::from("not-null-terminated");
        let err = read_cstring(&mut buf).expect_err("should reject unterminated cstring");
        assert!(
            matches!(err, ProxyError::Protocol(_)),
            "expected Protocol error, got {err:?}"
        );
    }

    #[test]
    fn test_read_cstring_sequence() {
        let mut buf = BytesMut::new();
        buf.put_slice(b"first\0second\0tail");
        let a = read_cstring(&mut buf).unwrap();
        let b = read_cstring(&mut buf).unwrap();
        assert_eq!(a, "first");
        assert_eq!(b, "second");
        assert_eq!(&buf[..], b"tail");
    }

    #[test]
    fn test_bind_message_param_values_are_bytes() {
        let mut payload = BytesMut::new();
        // Empty portal and statement names.
        payload.put_u8(0);
        payload.put_u8(0);
        // One parameter format code.
        payload.put_u16(1);
        payload.put_i16(0);
        // Two values: "hi" and NULL.
        payload.put_u16(2);
        payload.put_i32(2);
        payload.put_slice(b"hi");
        payload.put_i32(-1);
        // No result format codes.
        payload.put_u16(0);

        let bind = BindMessage::parse(payload).expect("parse failed");
        assert_eq!(bind.param_values.len(), 2);
        match &bind.param_values[0] {
            Some(b) => assert_eq!(b.as_ref(), b"hi"),
            None => panic!("first param must be Some"),
        }
        assert!(bind.param_values[1].is_none());
    }
}
886}