1use crate::{ProxyError, Result};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use std::collections::HashMap;
8
/// Kind of a PostgreSQL wire-protocol message, identified by its tag byte.
///
/// A few tag bytes are used by two different message kinds (`S`, `D`, `C`,
/// `E`), so a tag alone does not always determine the variant; see
/// [`MessageType::from_tag`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Startup packet; written without a tag byte.
    Startup,
    /// SSL negotiation request; written without a tag byte.
    SSLRequest,
    /// Authentication request (tag `R`).
    AuthRequest,
    /// Password message (tag `p`).
    Password,
    /// Simple query (tag `Q`).
    Query,
    /// Parse (tag `P`).
    Parse,
    /// Bind (tag `B`).
    Bind,
    /// Describe (tag `D`, shared with `DataRow`).
    Describe,
    /// Execute (tag `E`, shared with `ErrorResponse`).
    Execute,
    /// Sync (tag `S`, shared with `ParameterStatus`).
    Sync,
    /// Flush (tag `H`).
    Flush,
    /// Close (tag `C`, shared with `CommandComplete`).
    Close,
    /// Terminate (tag `X`).
    Terminate,
    /// CopyData (tag `d`).
    CopyData,
    /// CopyDone (tag `c`).
    CopyDone,
    /// CopyFail (tag `f`).
    CopyFail,
    /// FunctionCall (tag `F`).
    FunctionCall,
    /// BackendKeyData (tag `K`).
    BackendKeyData,
    /// ParameterStatus (tag `S`, shared with `Sync`).
    ParameterStatus,
    /// ReadyForQuery (tag `Z`).
    ReadyForQuery,
    /// RowDescription (tag `T`).
    RowDescription,
    /// DataRow (tag `D`, shared with `Describe`).
    DataRow,
    /// CommandComplete (tag `C`, shared with `Close`).
    CommandComplete,
    /// EmptyQueryResponse (tag `I`).
    EmptyQueryResponse,
    /// ErrorResponse (tag `E`, shared with `Execute`).
    ErrorResponse,
    /// NoticeResponse (tag `N`).
    NoticeResponse,
    /// NotificationResponse (tag `A`).
    NotificationResponse,
    /// ParseComplete (tag `1`).
    ParseComplete,
    /// BindComplete (tag `2`).
    BindComplete,
    /// CloseComplete (tag `3`).
    CloseComplete,
    /// PortalSuspended (tag `s`).
    PortalSuspended,
    /// NoData (tag `n`).
    NoData,
    /// ParameterDescription (tag `t`).
    ParameterDescription,
    /// A tag byte not recognised by this module; the raw byte is preserved.
    Unknown(u8),
}

impl MessageType {
    /// Maps a wire tag byte to a message type.
    ///
    /// The tags `S`, `D`, `C` and `E` each belong to two message kinds
    /// (`Sync`/`ParameterStatus`, `Describe`/`DataRow`,
    /// `Close`/`CommandComplete`, `Execute`/`ErrorResponse`). Without knowing
    /// which side sent the message, this returns `Sync`, `Describe`, `Close`
    /// and `Execute` respectively. Unrecognised bytes become
    /// [`MessageType::Unknown`].
    pub fn from_tag(tag: u8) -> Self {
        match tag {
            b'Q' => MessageType::Query,
            b'P' => MessageType::Parse,
            b'B' => MessageType::Bind,
            b'D' => MessageType::Describe,
            b'E' => MessageType::Execute,
            b'S' => MessageType::Sync,
            b'H' => MessageType::Flush,
            b'C' => MessageType::Close,
            b'X' => MessageType::Terminate,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            b'f' => MessageType::CopyFail,
            b'F' => MessageType::FunctionCall,
            b'p' => MessageType::Password,
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'I' => MessageType::EmptyQueryResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Returns the wire tag byte for this message type.
    ///
    /// Returns `None` only for the untagged `Startup` and `SSLRequest`
    /// messages. `Unknown(tag)` yields its stored tag, so an unrecognised
    /// message decoded by the codec re-encodes with its original tag byte.
    pub fn to_tag(&self) -> Option<u8> {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let tag = match self {
            MessageType::Startup | MessageType::SSLRequest => return None,
            MessageType::Query => b'Q',
            MessageType::Parse => b'P',
            MessageType::Bind => b'B',
            MessageType::Describe => b'D',
            MessageType::Execute => b'E',
            MessageType::Sync => b'S',
            MessageType::Flush => b'H',
            MessageType::Close => b'C',
            MessageType::Terminate => b'X',
            MessageType::CopyData => b'd',
            MessageType::CopyDone => b'c',
            MessageType::CopyFail => b'f',
            MessageType::FunctionCall => b'F',
            MessageType::Password => b'p',
            MessageType::AuthRequest => b'R',
            MessageType::BackendKeyData => b'K',
            MessageType::ParameterStatus => b'S',
            MessageType::ReadyForQuery => b'Z',
            MessageType::RowDescription => b'T',
            MessageType::DataRow => b'D',
            MessageType::CommandComplete => b'C',
            MessageType::EmptyQueryResponse => b'I',
            MessageType::ErrorResponse => b'E',
            MessageType::NoticeResponse => b'N',
            MessageType::NotificationResponse => b'A',
            MessageType::ParseComplete => b'1',
            MessageType::BindComplete => b'2',
            MessageType::CloseComplete => b'3',
            MessageType::PortalSuspended => b's',
            MessageType::NoData => b'n',
            MessageType::ParameterDescription => b't',
            MessageType::Unknown(tag) => *tag,
        };
        Some(tag)
    }
}
160
161#[derive(Debug, Clone)]
163pub struct Message {
164 pub msg_type: MessageType,
166 pub payload: BytesMut,
168}
169
170impl Message {
171 pub fn new(msg_type: MessageType, payload: BytesMut) -> Self {
173 Self { msg_type, payload }
174 }
175
176 pub fn empty(msg_type: MessageType) -> Self {
178 Self {
179 msg_type,
180 payload: BytesMut::new(),
181 }
182 }
183
184 pub fn encode(&self) -> BytesMut {
186 let mut buf = BytesMut::new();
187
188 if let Some(tag) = self.msg_type.to_tag() {
189 buf.put_u8(tag);
190 }
191
192 let len = self.payload.len() as u32 + 4;
194 buf.put_u32(len);
195 buf.extend_from_slice(&self.payload);
196
197 buf
198 }
199}
200
201pub struct ProtocolCodec {
203 max_message_size: usize,
205}
206
207impl Default for ProtocolCodec {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213impl ProtocolCodec {
214 pub fn new() -> Self {
216 Self {
217 max_message_size: 100 * 1024 * 1024, }
219 }
220
221 pub fn with_max_size(max_message_size: usize) -> Self {
223 Self { max_message_size }
224 }
225
226 pub fn decode_startup(&self, src: &mut BytesMut) -> Result<Option<StartupMessage>> {
228 if src.len() < 4 {
229 return Ok(None);
230 }
231
232 let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
233
234 if len > self.max_message_size {
235 return Err(ProxyError::Protocol(format!(
236 "Message too large: {} bytes",
237 len
238 )));
239 }
240
241 if src.len() < len {
242 return Ok(None);
243 }
244
245 src.advance(4);
246 let protocol_version = src.get_u32();
247
248 if protocol_version == 80877103 {
250 return Ok(Some(StartupMessage::SSLRequest));
251 }
252
253 if protocol_version == 80877102 {
255 let pid = src.get_u32();
256 let key = src.get_u32();
257 return Ok(Some(StartupMessage::CancelRequest { pid, key }));
258 }
259
260 let mut params = HashMap::new();
262 let remaining = len - 8; let mut param_bytes = src.split_to(remaining);
264
265 while param_bytes.has_remaining() {
266 let key = read_cstring(&mut param_bytes)?;
267 if key.is_empty() {
268 break;
269 }
270 let value = read_cstring(&mut param_bytes)?;
271 params.insert(key, value);
272 }
273
274 Ok(Some(StartupMessage::Startup {
275 protocol_version,
276 params,
277 }))
278 }
279
280 pub fn decode_message(&self, src: &mut BytesMut) -> Result<Option<Message>> {
282 if src.len() < 5 {
283 return Ok(None);
284 }
285
286 let tag = src[0];
287 let len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
288
289 if len > self.max_message_size {
290 return Err(ProxyError::Protocol(format!(
291 "Message too large: {} bytes",
292 len
293 )));
294 }
295
296 let total_len = 1 + len;
298 if src.len() < total_len {
299 return Ok(None);
300 }
301
302 src.advance(5); let payload = src.split_to(len - 4); let msg_type = MessageType::from_tag(tag);
306 Ok(Some(Message::new(msg_type, payload)))
307 }
308
309 pub fn encode_message(&self, msg: &Message) -> BytesMut {
311 msg.encode()
312 }
313}
314
/// A message from the untagged start of a client connection, as produced by
/// [`ProtocolCodec::decode_startup`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StartupMessage {
    /// A regular startup packet.
    Startup {
        /// Protocol version number sent by the client.
        protocol_version: u32,
        /// Connection parameters as key/value pairs.
        params: HashMap<String, String>,
    },
    /// SSL request (request code 80877103); carries no body.
    SSLRequest,
    /// Cancel request (request code 80877102).
    CancelRequest {
        /// Process id named in the request.
        pid: u32,
        /// Secret key named in the request.
        key: u32,
    },
}
328
329fn read_cstring(buf: &mut BytesMut) -> Result<String> {
336 let end = buf
337 .iter()
338 .position(|&b| b == 0)
339 .ok_or_else(|| ProxyError::Protocol(
340 "unterminated cstring in protocol buffer".to_string(),
341 ))?;
342
343 let bytes = buf.split_to(end);
344 buf.advance(1); String::from_utf8(bytes.into())
347 .map_err(|e| ProxyError::Protocol(format!("Invalid UTF-8 in cstring: {}", e)))
348}
349
350fn write_cstring(buf: &mut BytesMut, s: &str) {
352 buf.extend_from_slice(s.as_bytes());
353 buf.put_u8(0);
354}
355
/// Borrows the NUL-terminated string at the start of `payload` without copying.
///
/// Returns the bytes before the first NUL as `&str`. Returns `None` when
/// `payload` contains no NUL byte or those bytes are not valid UTF-8.
pub fn query_text(payload: &[u8]) -> Option<&str> {
    payload
        .iter()
        .position(|&byte| byte == 0)
        .and_then(|nul| std::str::from_utf8(&payload[..nul]).ok())
}
364
/// Returns `true` if `s` begins with `prefix`.
///
/// The comparison is byte-wise: ASCII letters match case-insensitively and
/// all other bytes must match exactly.
pub fn starts_with_ci(s: &str, prefix: &str) -> bool {
    let needle = prefix.as_bytes();
    s.as_bytes()
        .get(..needle.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(needle))
}
370
/// Returns `true` if `needle` occurs anywhere in `haystack`.
///
/// The comparison is byte-wise: ASCII letters match case-insensitively. An
/// empty `needle` always matches.
pub fn contains_ci(haystack: &str, needle: &str) -> bool {
    let (hay, pat) = (haystack.as_bytes(), needle.as_bytes());
    match pat.len() {
        0 => true,
        width if width > hay.len() => false,
        width => hay
            .windows(width)
            .any(|window| window.eq_ignore_ascii_case(pat)),
    }
}
384
385#[derive(Debug, Clone)]
387pub struct QueryMessage {
388 pub query: String,
389}
390
391impl QueryMessage {
392 pub fn parse(mut payload: BytesMut) -> Result<Self> {
394 let query = read_cstring(&mut payload)?;
395 Ok(Self { query })
396 }
397
398 pub fn encode(&self) -> Message {
400 let mut payload = BytesMut::new();
401 write_cstring(&mut payload, &self.query);
402 Message::new(MessageType::Query, payload)
403 }
404}
405
406#[derive(Debug, Clone)]
408pub struct ParseMessage {
409 pub name: String,
410 pub query: String,
411 pub param_types: Vec<u32>,
412}
413
414impl ParseMessage {
415 pub fn parse(mut payload: BytesMut) -> Result<Self> {
417 let name = read_cstring(&mut payload)?;
418 let query = read_cstring(&mut payload)?;
419
420 let num_params = payload.get_u16() as usize;
421 let mut param_types = Vec::with_capacity(num_params);
422
423 for _ in 0..num_params {
424 param_types.push(payload.get_u32());
425 }
426
427 Ok(Self {
428 name,
429 query,
430 param_types,
431 })
432 }
433
434 pub fn encode(&self) -> Message {
436 let mut payload = BytesMut::new();
437 write_cstring(&mut payload, &self.name);
438 write_cstring(&mut payload, &self.query);
439 payload.put_u16(self.param_types.len() as u16);
440 for &t in &self.param_types {
441 payload.put_u32(t);
442 }
443 Message::new(MessageType::Parse, payload)
444 }
445}
446
447#[derive(Debug, Clone)]
453pub struct BindMessage {
454 pub portal: String,
455 pub statement: String,
456 pub param_formats: Vec<i16>,
457 pub param_values: Vec<Option<Bytes>>,
458 pub result_formats: Vec<i16>,
459}
460
461impl BindMessage {
462 pub fn parse(mut payload: BytesMut) -> Result<Self> {
464 let portal = read_cstring(&mut payload)?;
465 let statement = read_cstring(&mut payload)?;
466
467 let num_formats = payload.get_u16() as usize;
469 let mut param_formats = Vec::with_capacity(num_formats);
470 for _ in 0..num_formats {
471 param_formats.push(payload.get_i16());
472 }
473
474 let num_values = payload.get_u16() as usize;
478 let mut param_values = Vec::with_capacity(num_values);
479 for _ in 0..num_values {
480 let len = payload.get_i32();
481 if len == -1 {
482 param_values.push(None);
483 } else {
484 let value = payload.split_to(len as usize).freeze();
485 param_values.push(Some(value));
486 }
487 }
488
489 let num_result_formats = payload.get_u16() as usize;
491 let mut result_formats = Vec::with_capacity(num_result_formats);
492 for _ in 0..num_result_formats {
493 result_formats.push(payload.get_i16());
494 }
495
496 Ok(Self {
497 portal,
498 statement,
499 param_formats,
500 param_values,
501 result_formats,
502 })
503 }
504}
505
506#[derive(Debug, Clone)]
508pub struct ExecuteMessage {
509 pub portal: String,
510 pub max_rows: i32,
511}
512
513impl ExecuteMessage {
514 pub fn parse(mut payload: BytesMut) -> Result<Self> {
516 let portal = read_cstring(&mut payload)?;
517 let max_rows = payload.get_i32();
518 Ok(Self { portal, max_rows })
519 }
520
521 pub fn encode(&self) -> Message {
523 let mut payload = BytesMut::new();
524 write_cstring(&mut payload, &self.portal);
525 payload.put_i32(self.max_rows);
526 Message::new(MessageType::Execute, payload)
527 }
528}
529
530#[derive(Debug, Clone)]
532pub struct ErrorResponse {
533 pub fields: HashMap<char, String>,
534}
535
536impl ErrorResponse {
537 pub fn parse(mut payload: BytesMut) -> Result<Self> {
539 let mut fields = HashMap::new();
540
541 while payload.has_remaining() {
542 let code = payload.get_u8();
543 if code == 0 {
544 break;
545 }
546 let value = read_cstring(&mut payload)?;
547 fields.insert(code as char, value);
548 }
549
550 Ok(Self { fields })
551 }
552
553 pub fn severity(&self) -> Option<&str> {
555 self.fields.get(&'S').map(|s| s.as_str())
556 }
557
558 pub fn code(&self) -> Option<&str> {
560 self.fields.get(&'C').map(|s| s.as_str())
561 }
562
563 pub fn message(&self) -> Option<&str> {
565 self.fields.get(&'M').map(|s| s.as_str())
566 }
567
568 pub fn encode(&self) -> Message {
570 let mut payload = BytesMut::new();
571 for (&code, value) in &self.fields {
572 payload.put_u8(code as u8);
573 write_cstring(&mut payload, value);
574 }
575 payload.put_u8(0);
576 Message::new(MessageType::ErrorResponse, payload)
577 }
578}
579
/// Transaction status indicator, carried on the wire as one byte
/// (`I`, `T` or `E`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Status byte `I`.
    Idle,
    /// Status byte `T`.
    InTransaction,
    /// Status byte `E`.
    Failed,
}

impl TransactionStatus {
    /// Decodes a status byte.
    ///
    /// `T` and `E` map to their variants. Every other byte, including
    /// unrecognised ones, is treated as [`TransactionStatus::Idle`].
    pub fn from_byte(b: u8) -> Self {
        if b == b'T' {
            Self::InTransaction
        } else if b == b'E' {
            Self::Failed
        } else {
            Self::Idle
        }
    }

    /// Encodes this status as its wire byte.
    pub fn to_byte(&self) -> u8 {
        match *self {
            Self::Idle => b'I',
            Self::InTransaction => b'T',
            Self::Failed => b'E',
        }
    }
}
611
612#[derive(Debug, Clone)]
614pub struct CommandComplete {
615 pub tag: String,
616}
617
618impl CommandComplete {
619 pub fn parse(mut payload: BytesMut) -> Result<Self> {
621 let tag = read_cstring(&mut payload)?;
622 Ok(Self { tag })
623 }
624
625 pub fn encode(&self) -> Message {
627 let mut payload = BytesMut::new();
628 write_cstring(&mut payload, &self.tag);
629 Message::new(MessageType::CommandComplete, payload)
630 }
631
632 pub fn rows_affected(&self) -> Option<u64> {
634 let parts: Vec<&str> = self.tag.split_whitespace().collect();
635 if parts.len() >= 2 {
636 parts.last()?.parse().ok()
637 } else {
638 None
639 }
640 }
641}
642
643#[derive(Debug, Clone)]
645pub enum AuthRequest {
646 Ok,
648 CleartextPassword,
650 Md5Password { salt: [u8; 4] },
652 SASL { mechanisms: Vec<String> },
654 SASLContinue { data: Vec<u8> },
656 SASLFinal { data: Vec<u8> },
658 Unknown(i32),
660}
661
662impl AuthRequest {
663 pub fn parse(mut payload: BytesMut) -> Result<Self> {
665 let auth_type = payload.get_i32();
666
667 Ok(match auth_type {
668 0 => AuthRequest::Ok,
669 3 => AuthRequest::CleartextPassword,
670 5 => {
671 let mut salt = [0u8; 4];
672 payload.copy_to_slice(&mut salt);
673 AuthRequest::Md5Password { salt }
674 }
675 10 => {
676 let mut mechanisms = Vec::new();
677 loop {
678 let mech = read_cstring(&mut payload)?;
679 if mech.is_empty() {
680 break;
681 }
682 mechanisms.push(mech);
683 }
684 AuthRequest::SASL { mechanisms }
685 }
686 11 => {
687 let data = payload.to_vec();
688 AuthRequest::SASLContinue { data }
689 }
690 12 => {
691 let data = payload.to_vec();
692 AuthRequest::SASLFinal { data }
693 }
694 _ => AuthRequest::Unknown(auth_type),
695 })
696 }
697
698 pub fn encode(&self) -> Message {
700 let mut payload = BytesMut::new();
701
702 match self {
703 AuthRequest::Ok => {
704 payload.put_i32(0);
705 }
706 AuthRequest::CleartextPassword => {
707 payload.put_i32(3);
708 }
709 AuthRequest::Md5Password { salt } => {
710 payload.put_i32(5);
711 payload.extend_from_slice(salt);
712 }
713 AuthRequest::SASL { mechanisms } => {
714 payload.put_i32(10);
715 for mech in mechanisms {
716 write_cstring(&mut payload, mech);
717 }
718 payload.put_u8(0);
719 }
720 AuthRequest::SASLContinue { data } => {
721 payload.put_i32(11);
722 payload.extend_from_slice(data);
723 }
724 AuthRequest::SASLFinal { data } => {
725 payload.put_i32(12);
726 payload.extend_from_slice(data);
727 }
728 AuthRequest::Unknown(t) => {
729 payload.put_i32(*t);
730 }
731 }
732
733 Message::new(MessageType::AuthRequest, payload)
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_round_trip() {
        let types = vec![
            MessageType::Query,
            MessageType::Parse,
            MessageType::Bind,
            MessageType::Execute,
            MessageType::Sync,
        ];

        for msg_type in types {
            if let Some(tag) = msg_type.to_tag() {
                let decoded = MessageType::from_tag(tag);
                assert_eq!(decoded, msg_type);
            }
        }
    }

    #[test]
    fn test_auth_request_tag_mapping() {
        assert_eq!(MessageType::from_tag(b'R'), MessageType::AuthRequest);
        assert_eq!(MessageType::AuthRequest.to_tag(), Some(b'R'));
    }

    #[test]
    fn test_query_message() {
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let msg = query.encode();
        assert_eq!(msg.msg_type, MessageType::Query);

        let decoded = QueryMessage::parse(msg.payload).unwrap();
        assert_eq!(decoded.query, "SELECT 1");
    }

    #[test]
    fn test_error_response() {
        let mut fields = HashMap::new();
        fields.insert('S', "ERROR".to_string());
        fields.insert('C', "42P01".to_string());
        fields.insert('M', "relation does not exist".to_string());

        let err = ErrorResponse { fields };
        assert_eq!(err.severity(), Some("ERROR"));
        assert_eq!(err.code(), Some("42P01"));
        assert_eq!(err.message(), Some("relation does not exist"));
    }

    #[test]
    fn test_command_complete() {
        let cmd = CommandComplete {
            tag: "INSERT 0 5".to_string(),
        };
        assert_eq!(cmd.rows_affected(), Some(5));

        let cmd2 = CommandComplete {
            tag: "SELECT 100".to_string(),
        };
        assert_eq!(cmd2.rows_affected(), Some(100));
    }

    #[test]
    fn test_transaction_status() {
        assert_eq!(TransactionStatus::from_byte(b'I'), TransactionStatus::Idle);
        assert_eq!(
            TransactionStatus::from_byte(b'T'),
            TransactionStatus::InTransaction
        );
        assert_eq!(TransactionStatus::from_byte(b'E'), TransactionStatus::Failed);

        assert_eq!(TransactionStatus::Idle.to_byte(), b'I');
        assert_eq!(TransactionStatus::InTransaction.to_byte(), b'T');
        assert_eq!(TransactionStatus::Failed.to_byte(), b'E');
    }

    #[test]
    fn test_protocol_codec() {
        let codec = ProtocolCodec::new();
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let msg = query.encode();
        let encoded = codec.encode_message(&msg);

        assert!(encoded.len() > 5);
        assert_eq!(encoded[0], b'Q');
    }

    /// A buffer with no NUL byte must yield a protocol error, not a panic.
    #[test]
    fn test_read_cstring_unterminated() {
        let mut buf = BytesMut::from("not-null-terminated");
        let err = read_cstring(&mut buf).expect_err("should reject unterminated cstring");
        assert!(
            matches!(err, ProxyError::Protocol(_)),
            "expected Protocol error, got {err:?}"
        );
    }

    /// Consecutive reads consume each string plus its terminator and leave
    /// trailing bytes in place.
    #[test]
    fn test_read_cstring_sequence() {
        let mut buf = BytesMut::new();
        buf.put_slice(b"first\0second\0tail");
        let a = read_cstring(&mut buf).unwrap();
        let b = read_cstring(&mut buf).unwrap();
        assert_eq!(a, "first");
        assert_eq!(b, "second");
        assert_eq!(&buf[..], b"tail");
    }

    /// Bind values decode to `Some(bytes)` or `None` for a -1 length.
    #[test]
    fn test_bind_message_param_values_are_bytes() {
        let mut payload = BytesMut::new();
        // Empty portal and statement names.
        payload.put_u8(0);
        payload.put_u8(0);
        // One parameter format code.
        payload.put_u16(1);
        payload.put_i16(0);
        // Two values: "hi" and NULL.
        payload.put_u16(2);
        payload.put_i32(2);
        payload.put_slice(b"hi");
        payload.put_i32(-1);
        // No result format codes.
        payload.put_u16(0);

        let bind = BindMessage::parse(payload).expect("parse failed");
        assert_eq!(bind.param_values.len(), 2);
        match &bind.param_values[0] {
            Some(b) => assert_eq!(b.as_ref(), b"hi"),
            None => panic!("first param must be Some"),
        }
        assert!(bind.param_values[1].is_none());
    }

    /// An encoded message decodes back, and only its own frame is consumed.
    #[test]
    fn test_decode_message_round_trip() {
        let codec = ProtocolCodec::new();
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let mut src = codec.encode_message(&query.encode());
        src.extend_from_slice(b"next");

        let msg = codec
            .decode_message(&mut src)
            .unwrap()
            .expect("complete message");
        assert_eq!(msg.msg_type, MessageType::Query);
        assert_eq!(QueryMessage::parse(msg.payload).unwrap().query, "SELECT 1");
        assert_eq!(&src[..], b"next");
    }

    /// A partial frame yields `None` and leaves the buffer untouched.
    #[test]
    fn test_decode_message_incomplete() {
        let codec = ProtocolCodec::new();
        let query = QueryMessage {
            query: "SELECT 1".to_string(),
        };
        let full = query.encode().encode();
        let mut src = BytesMut::from(&full[..full.len() - 1]);

        assert!(codec.decode_message(&mut src).unwrap().is_none());
        assert_eq!(src.len(), full.len() - 1);
    }

    /// A startup packet's parameters are decoded, and following bytes remain.
    #[test]
    fn test_decode_startup_params() {
        let codec = ProtocolCodec::new();
        let mut body = BytesMut::new();
        body.put_u32(196_608);
        write_cstring(&mut body, "user");
        write_cstring(&mut body, "alice");
        body.put_u8(0);

        let mut src = BytesMut::new();
        src.put_u32(body.len() as u32 + 4);
        src.extend_from_slice(&body);
        src.extend_from_slice(b"rest");

        match codec.decode_startup(&mut src).unwrap() {
            Some(StartupMessage::Startup {
                protocol_version,
                params,
            }) => {
                assert_eq!(protocol_version, 196_608);
                assert_eq!(params.len(), 1);
                assert_eq!(params.get("user").map(String::as_str), Some("alice"));
            }
            other => panic!("expected Startup, got {other:?}"),
        }
        assert_eq!(&src[..], b"rest");
    }

    /// The SSL request code is recognised and its 8-byte frame consumed.
    #[test]
    fn test_decode_startup_ssl_request() {
        let codec = ProtocolCodec::new();
        let mut src = BytesMut::new();
        src.put_u32(8);
        src.put_u32(80_877_103);

        assert!(matches!(
            codec.decode_startup(&mut src).unwrap(),
            Some(StartupMessage::SSLRequest)
        ));
        assert!(src.is_empty());
    }
}