1use crate::{ProxyError, Result};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use std::collections::HashMap;
8
/// Kind of a wire-protocol message.
///
/// Several one-byte tags are shared between client-sent and server-sent
/// messages (`S`, `D`, `C`, `E`, `H`), so a tag alone is ambiguous without
/// knowing the direction of the stream. See [`MessageType::from_tag`] and
/// [`MessageType::from_backend_tag`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageType {
    /// Untagged startup packet (`to_tag` returns `None`).
    Startup,
    /// Untagged SSL negotiation packet (`to_tag` returns `None`).
    SSLRequest,
    /// Tag `R`.
    AuthRequest,
    /// Tag `p`.
    Password,
    /// Tag `Q`.
    Query,
    /// Tag `P`.
    Parse,
    /// Tag `B`.
    Bind,
    /// Tag `D` (client side).
    Describe,
    /// Tag `E` (client side).
    Execute,
    /// Tag `S` (client side).
    Sync,
    /// Tag `H` (client side).
    Flush,
    /// Tag `C` (client side).
    Close,
    /// Tag `X`.
    Terminate,
    /// Tag `d`.
    CopyData,
    /// Tag `c`.
    CopyDone,
    /// Tag `f`.
    CopyFail,
    /// Tag `F`.
    FunctionCall,
    /// Tag `K`.
    BackendKeyData,
    /// Tag `S` (server side).
    ParameterStatus,
    /// Tag `Z`.
    ReadyForQuery,
    /// Tag `T`.
    RowDescription,
    /// Tag `D` (server side).
    DataRow,
    /// Tag `C` (server side).
    CommandComplete,
    /// Tag `I`.
    EmptyQueryResponse,
    /// Tag `E` (server side).
    ErrorResponse,
    /// Tag `N`.
    NoticeResponse,
    /// Tag `A`.
    NotificationResponse,
    /// Tag `1`.
    ParseComplete,
    /// Tag `2`.
    BindComplete,
    /// Tag `3`.
    CloseComplete,
    /// Tag `s`.
    PortalSuspended,
    /// Tag `n`.
    NoData,
    /// Tag `t`.
    ParameterDescription,
    /// Any tag the decoder does not recognise; the raw byte is preserved so
    /// the message can be re-encoded unchanged.
    Unknown(u8),
}

impl MessageType {
    /// Decodes a tag byte, resolving shared tags to their client-side meaning.
    ///
    /// `S` → `Sync`, `D` → `Describe`, `C` → `Close`, `E` → `Execute` and
    /// `H` → `Flush`. Server-only tags that do not collide (for example `R`,
    /// `Z`, `T`) are also recognised. Use [`MessageType::from_backend_tag`]
    /// when decoding traffic that comes from the server.
    pub fn from_tag(tag: u8) -> Self {
        match tag {
            b'Q' => MessageType::Query,
            b'P' => MessageType::Parse,
            b'B' => MessageType::Bind,
            b'D' => MessageType::Describe,
            b'E' => MessageType::Execute,
            b'S' => MessageType::Sync,
            b'H' => MessageType::Flush,
            b'C' => MessageType::Close,
            b'X' => MessageType::Terminate,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            b'f' => MessageType::CopyFail,
            b'F' => MessageType::FunctionCall,
            b'p' => MessageType::Password,
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'I' => MessageType::EmptyQueryResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Decodes a tag byte seen on the server → client stream.
    ///
    /// Shared tags resolve to their server-side meaning: `S` →
    /// `ParameterStatus`, `D` → `DataRow`, `C` → `CommandComplete`,
    /// `E` → `ErrorResponse`. Client-only tags, and server tags this enum
    /// has no variant for (including `H`), become [`MessageType::Unknown`].
    pub fn from_backend_tag(tag: u8) -> Self {
        match tag {
            b'R' => MessageType::AuthRequest,
            b'K' => MessageType::BackendKeyData,
            b'S' => MessageType::ParameterStatus,
            b'Z' => MessageType::ReadyForQuery,
            b'T' => MessageType::RowDescription,
            b'D' => MessageType::DataRow,
            b'C' => MessageType::CommandComplete,
            b'I' => MessageType::EmptyQueryResponse,
            b'E' => MessageType::ErrorResponse,
            b'N' => MessageType::NoticeResponse,
            b'A' => MessageType::NotificationResponse,
            b'1' => MessageType::ParseComplete,
            b'2' => MessageType::BindComplete,
            b'3' => MessageType::CloseComplete,
            b's' => MessageType::PortalSuspended,
            b'n' => MessageType::NoData,
            b't' => MessageType::ParameterDescription,
            b'd' => MessageType::CopyData,
            b'c' => MessageType::CopyDone,
            _ => MessageType::Unknown(tag),
        }
    }

    /// Returns the tag byte written before this message's length field.
    ///
    /// Only the untagged startup-phase kinds (`Startup`, `SSLRequest`) return
    /// `None`. `Unknown(t)` returns `t`, so decoded-but-unrecognised messages
    /// re-encode byte-for-byte.
    pub fn to_tag(&self) -> Option<u8> {
        // Exhaustive on purpose: a new variant must be given a tag here.
        let tag = match self {
            MessageType::Startup | MessageType::SSLRequest => return None,
            MessageType::AuthRequest => b'R',
            MessageType::Query => b'Q',
            MessageType::Parse => b'P',
            MessageType::Bind => b'B',
            MessageType::Describe => b'D',
            MessageType::Execute => b'E',
            MessageType::Sync => b'S',
            MessageType::Flush => b'H',
            MessageType::Close => b'C',
            MessageType::Terminate => b'X',
            MessageType::CopyData => b'd',
            MessageType::CopyDone => b'c',
            MessageType::CopyFail => b'f',
            MessageType::FunctionCall => b'F',
            MessageType::Password => b'p',
            MessageType::BackendKeyData => b'K',
            MessageType::ParameterStatus => b'S',
            MessageType::ReadyForQuery => b'Z',
            MessageType::RowDescription => b'T',
            MessageType::DataRow => b'D',
            MessageType::CommandComplete => b'C',
            MessageType::EmptyQueryResponse => b'I',
            MessageType::ErrorResponse => b'E',
            MessageType::NoticeResponse => b'N',
            MessageType::NotificationResponse => b'A',
            MessageType::ParseComplete => b'1',
            MessageType::BindComplete => b'2',
            MessageType::CloseComplete => b'3',
            MessageType::PortalSuspended => b's',
            MessageType::NoData => b'n',
            MessageType::ParameterDescription => b't',
            MessageType::Unknown(tag) => *tag,
        };
        Some(tag)
    }
}
158
159#[derive(Debug, Clone)]
161pub struct Message {
162 pub msg_type: MessageType,
164 pub payload: BytesMut,
166}
167
168impl Message {
169 pub fn new(msg_type: MessageType, payload: BytesMut) -> Self {
171 Self { msg_type, payload }
172 }
173
174 pub fn empty(msg_type: MessageType) -> Self {
176 Self {
177 msg_type,
178 payload: BytesMut::new(),
179 }
180 }
181
182 pub fn encode(&self) -> BytesMut {
184 let mut buf = BytesMut::new();
185
186 if let Some(tag) = self.msg_type.to_tag() {
187 buf.put_u8(tag);
188 }
189
190 let len = self.payload.len() as u32 + 4;
192 buf.put_u32(len);
193 buf.extend_from_slice(&self.payload);
194
195 buf
196 }
197}
198
199pub struct ProtocolCodec {
201 max_message_size: usize,
203}
204
205impl Default for ProtocolCodec {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl ProtocolCodec {
212 pub fn new() -> Self {
214 Self {
215 max_message_size: 100 * 1024 * 1024, }
217 }
218
219 pub fn with_max_size(max_message_size: usize) -> Self {
221 Self { max_message_size }
222 }
223
224 pub fn decode_startup(&self, src: &mut BytesMut) -> Result<Option<StartupMessage>> {
226 if src.len() < 4 {
227 return Ok(None);
228 }
229
230 let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
231
232 if len > self.max_message_size {
233 return Err(ProxyError::Protocol(format!(
234 "Message too large: {} bytes",
235 len
236 )));
237 }
238
239 if src.len() < len {
240 return Ok(None);
241 }
242
243 src.advance(4);
244 let protocol_version = src.get_u32();
245
246 if protocol_version == 80877103 {
248 return Ok(Some(StartupMessage::SSLRequest));
249 }
250
251 if protocol_version == 80877102 {
253 let pid = src.get_u32();
254 let key = src.get_u32();
255 return Ok(Some(StartupMessage::CancelRequest { pid, key }));
256 }
257
258 let mut params = HashMap::new();
260 let remaining = len - 8; let mut param_bytes = src.split_to(remaining);
262
263 while param_bytes.has_remaining() {
264 let key = read_cstring(&mut param_bytes)?;
265 if key.is_empty() {
266 break;
267 }
268 let value = read_cstring(&mut param_bytes)?;
269 params.insert(key, value);
270 }
271
272 Ok(Some(StartupMessage::Startup {
273 protocol_version,
274 params,
275 }))
276 }
277
278 pub fn decode_message(&self, src: &mut BytesMut) -> Result<Option<Message>> {
280 if src.len() < 5 {
281 return Ok(None);
282 }
283
284 let tag = src[0];
285 let len = u32::from_be_bytes([src[1], src[2], src[3], src[4]]) as usize;
286
287 if len > self.max_message_size {
288 return Err(ProxyError::Protocol(format!(
289 "Message too large: {} bytes",
290 len
291 )));
292 }
293
294 let total_len = 1 + len;
296 if src.len() < total_len {
297 return Ok(None);
298 }
299
300 src.advance(5); let payload = src.split_to(len - 4); let msg_type = MessageType::from_tag(tag);
304 Ok(Some(Message::new(msg_type, payload)))
305 }
306
307 pub fn encode_message(&self, msg: &Message) -> BytesMut {
309 msg.encode()
310 }
311}
312
/// A packet from the untagged startup phase, as produced by
/// `ProtocolCodec::decode_startup`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StartupMessage {
    /// A regular startup packet.
    Startup {
        /// Protocol word exactly as received.
        protocol_version: u32,
        /// Key/value string pairs read from the packet body.
        params: HashMap<String, String>,
    },
    /// Packet whose protocol word is 80877103.
    SSLRequest,
    /// Packet whose protocol word is 80877102, with the two `u32` fields
    /// that follow it.
    CancelRequest { pid: u32, key: u32 },
}
326
327fn read_cstring(buf: &mut BytesMut) -> Result<String> {
334 let end = buf
335 .iter()
336 .position(|&b| b == 0)
337 .ok_or_else(|| ProxyError::Protocol(
338 "unterminated cstring in protocol buffer".to_string(),
339 ))?;
340
341 let bytes = buf.split_to(end);
342 buf.advance(1); String::from_utf8(bytes.into())
345 .map_err(|e| ProxyError::Protocol(format!("Invalid UTF-8 in cstring: {}", e)))
346}
347
348fn write_cstring(buf: &mut BytesMut, s: &str) {
350 buf.extend_from_slice(s.as_bytes());
351 buf.put_u8(0);
352}
353
354#[derive(Debug, Clone)]
356pub struct QueryMessage {
357 pub query: String,
358}
359
360impl QueryMessage {
361 pub fn parse(mut payload: BytesMut) -> Result<Self> {
363 let query = read_cstring(&mut payload)?;
364 Ok(Self { query })
365 }
366
367 pub fn encode(&self) -> Message {
369 let mut payload = BytesMut::new();
370 write_cstring(&mut payload, &self.query);
371 Message::new(MessageType::Query, payload)
372 }
373}
374
375#[derive(Debug, Clone)]
377pub struct ParseMessage {
378 pub name: String,
379 pub query: String,
380 pub param_types: Vec<u32>,
381}
382
383impl ParseMessage {
384 pub fn parse(mut payload: BytesMut) -> Result<Self> {
386 let name = read_cstring(&mut payload)?;
387 let query = read_cstring(&mut payload)?;
388
389 let num_params = payload.get_u16() as usize;
390 let mut param_types = Vec::with_capacity(num_params);
391
392 for _ in 0..num_params {
393 param_types.push(payload.get_u32());
394 }
395
396 Ok(Self {
397 name,
398 query,
399 param_types,
400 })
401 }
402
403 pub fn encode(&self) -> Message {
405 let mut payload = BytesMut::new();
406 write_cstring(&mut payload, &self.name);
407 write_cstring(&mut payload, &self.query);
408 payload.put_u16(self.param_types.len() as u16);
409 for &t in &self.param_types {
410 payload.put_u32(t);
411 }
412 Message::new(MessageType::Parse, payload)
413 }
414}
415
416#[derive(Debug, Clone)]
422pub struct BindMessage {
423 pub portal: String,
424 pub statement: String,
425 pub param_formats: Vec<i16>,
426 pub param_values: Vec<Option<Bytes>>,
427 pub result_formats: Vec<i16>,
428}
429
430impl BindMessage {
431 pub fn parse(mut payload: BytesMut) -> Result<Self> {
433 let portal = read_cstring(&mut payload)?;
434 let statement = read_cstring(&mut payload)?;
435
436 let num_formats = payload.get_u16() as usize;
438 let mut param_formats = Vec::with_capacity(num_formats);
439 for _ in 0..num_formats {
440 param_formats.push(payload.get_i16());
441 }
442
443 let num_values = payload.get_u16() as usize;
447 let mut param_values = Vec::with_capacity(num_values);
448 for _ in 0..num_values {
449 let len = payload.get_i32();
450 if len == -1 {
451 param_values.push(None);
452 } else {
453 let value = payload.split_to(len as usize).freeze();
454 param_values.push(Some(value));
455 }
456 }
457
458 let num_result_formats = payload.get_u16() as usize;
460 let mut result_formats = Vec::with_capacity(num_result_formats);
461 for _ in 0..num_result_formats {
462 result_formats.push(payload.get_i16());
463 }
464
465 Ok(Self {
466 portal,
467 statement,
468 param_formats,
469 param_values,
470 result_formats,
471 })
472 }
473}
474
475#[derive(Debug, Clone)]
477pub struct ExecuteMessage {
478 pub portal: String,
479 pub max_rows: i32,
480}
481
482impl ExecuteMessage {
483 pub fn parse(mut payload: BytesMut) -> Result<Self> {
485 let portal = read_cstring(&mut payload)?;
486 let max_rows = payload.get_i32();
487 Ok(Self { portal, max_rows })
488 }
489
490 pub fn encode(&self) -> Message {
492 let mut payload = BytesMut::new();
493 write_cstring(&mut payload, &self.portal);
494 payload.put_i32(self.max_rows);
495 Message::new(MessageType::Execute, payload)
496 }
497}
498
499#[derive(Debug, Clone)]
501pub struct ErrorResponse {
502 pub fields: HashMap<char, String>,
503}
504
505impl ErrorResponse {
506 pub fn parse(mut payload: BytesMut) -> Result<Self> {
508 let mut fields = HashMap::new();
509
510 while payload.has_remaining() {
511 let code = payload.get_u8();
512 if code == 0 {
513 break;
514 }
515 let value = read_cstring(&mut payload)?;
516 fields.insert(code as char, value);
517 }
518
519 Ok(Self { fields })
520 }
521
522 pub fn severity(&self) -> Option<&str> {
524 self.fields.get(&'S').map(|s| s.as_str())
525 }
526
527 pub fn code(&self) -> Option<&str> {
529 self.fields.get(&'C').map(|s| s.as_str())
530 }
531
532 pub fn message(&self) -> Option<&str> {
534 self.fields.get(&'M').map(|s| s.as_str())
535 }
536
537 pub fn encode(&self) -> Message {
539 let mut payload = BytesMut::new();
540 for (&code, value) in &self.fields {
541 payload.put_u8(code as u8);
542 write_cstring(&mut payload, value);
543 }
544 payload.put_u8(0);
545 Message::new(MessageType::ErrorResponse, payload)
546 }
547}
548
/// Transaction state carried as a single status byte (`I`, `T` or `E`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Status byte `I`.
    Idle,
    /// Status byte `T`.
    InTransaction,
    /// Status byte `E`.
    Failed,
}

impl TransactionStatus {
    /// Decodes a status byte.
    ///
    /// `T` and `E` map to their own variants. Every other byte, recognised
    /// or not, falls back to [`TransactionStatus::Idle`].
    pub fn from_byte(b: u8) -> Self {
        if b == b'T' {
            Self::InTransaction
        } else if b == b'E' {
            Self::Failed
        } else {
            Self::Idle
        }
    }

    /// Encodes this status as its single status byte.
    pub fn to_byte(&self) -> u8 {
        let byte = match *self {
            Self::Idle => b'I',
            Self::InTransaction => b'T',
            Self::Failed => b'E',
        };
        byte
    }
}
580
581#[derive(Debug, Clone)]
583pub struct CommandComplete {
584 pub tag: String,
585}
586
587impl CommandComplete {
588 pub fn parse(mut payload: BytesMut) -> Result<Self> {
590 let tag = read_cstring(&mut payload)?;
591 Ok(Self { tag })
592 }
593
594 pub fn encode(&self) -> Message {
596 let mut payload = BytesMut::new();
597 write_cstring(&mut payload, &self.tag);
598 Message::new(MessageType::CommandComplete, payload)
599 }
600
601 pub fn rows_affected(&self) -> Option<u64> {
603 let parts: Vec<&str> = self.tag.split_whitespace().collect();
604 if parts.len() >= 2 {
605 parts.last()?.parse().ok()
606 } else {
607 None
608 }
609 }
610}
611
612#[derive(Debug, Clone)]
614pub enum AuthRequest {
615 Ok,
617 CleartextPassword,
619 Md5Password { salt: [u8; 4] },
621 SASL { mechanisms: Vec<String> },
623 SASLContinue { data: Vec<u8> },
625 SASLFinal { data: Vec<u8> },
627 Unknown(i32),
629}
630
631impl AuthRequest {
632 pub fn parse(mut payload: BytesMut) -> Result<Self> {
634 let auth_type = payload.get_i32();
635
636 Ok(match auth_type {
637 0 => AuthRequest::Ok,
638 3 => AuthRequest::CleartextPassword,
639 5 => {
640 let mut salt = [0u8; 4];
641 payload.copy_to_slice(&mut salt);
642 AuthRequest::Md5Password { salt }
643 }
644 10 => {
645 let mut mechanisms = Vec::new();
646 loop {
647 let mech = read_cstring(&mut payload)?;
648 if mech.is_empty() {
649 break;
650 }
651 mechanisms.push(mech);
652 }
653 AuthRequest::SASL { mechanisms }
654 }
655 11 => {
656 let data = payload.to_vec();
657 AuthRequest::SASLContinue { data }
658 }
659 12 => {
660 let data = payload.to_vec();
661 AuthRequest::SASLFinal { data }
662 }
663 _ => AuthRequest::Unknown(auth_type),
664 })
665 }
666
667 pub fn encode(&self) -> Message {
669 let mut payload = BytesMut::new();
670
671 match self {
672 AuthRequest::Ok => {
673 payload.put_i32(0);
674 }
675 AuthRequest::CleartextPassword => {
676 payload.put_i32(3);
677 }
678 AuthRequest::Md5Password { salt } => {
679 payload.put_i32(5);
680 payload.extend_from_slice(salt);
681 }
682 AuthRequest::SASL { mechanisms } => {
683 payload.put_i32(10);
684 for mech in mechanisms {
685 write_cstring(&mut payload, mech);
686 }
687 payload.put_u8(0);
688 }
689 AuthRequest::SASLContinue { data } => {
690 payload.put_i32(11);
691 payload.extend_from_slice(data);
692 }
693 AuthRequest::SASLFinal { data } => {
694 payload.put_i32(12);
695 payload.extend_from_slice(data);
696 }
697 AuthRequest::Unknown(t) => {
698 payload.put_i32(*t);
699 }
700 }
701
702 Message::new(MessageType::AuthRequest, payload)
703 }
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_round_trip() {
        let frontend = [
            MessageType::Query,
            MessageType::Parse,
            MessageType::Bind,
            MessageType::Execute,
            MessageType::Sync,
        ];

        for original in frontend.iter() {
            if let Some(tag) = original.to_tag() {
                assert_eq!(&MessageType::from_tag(tag), original);
            }
        }
    }

    #[test]
    fn test_query_message() {
        let msg = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        assert_eq!(msg.msg_type, MessageType::Query);

        let round_tripped = QueryMessage::parse(msg.payload).unwrap();
        assert_eq!(round_tripped.query, "SELECT 1");
    }

    #[test]
    fn test_error_response() {
        let fields: HashMap<char, String> = [
            ('S', "ERROR"),
            ('C', "42P01"),
            ('M', "relation does not exist"),
        ]
        .iter()
        .map(|&(code, value)| (code, value.to_string()))
        .collect();

        let err = ErrorResponse { fields };
        assert_eq!(err.severity(), Some("ERROR"));
        assert_eq!(err.code(), Some("42P01"));
        assert_eq!(err.message(), Some("relation does not exist"));
    }

    #[test]
    fn test_command_complete() {
        let cases = [("INSERT 0 5", Some(5)), ("SELECT 100", Some(100))];
        for &(tag, expected) in cases.iter() {
            let cmd = CommandComplete {
                tag: tag.to_string(),
            };
            assert_eq!(cmd.rows_affected(), expected);
        }
    }

    #[test]
    fn test_transaction_status() {
        let table = [
            (b'I', TransactionStatus::Idle),
            (b'T', TransactionStatus::InTransaction),
            (b'E', TransactionStatus::Failed),
        ];
        for &(byte, status) in table.iter() {
            assert_eq!(TransactionStatus::from_byte(byte), status);
            assert_eq!(status.to_byte(), byte);
        }
    }

    #[test]
    fn test_protocol_codec() {
        let msg = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        let encoded = ProtocolCodec::new().encode_message(&msg);

        assert!(encoded.len() > 5);
        assert_eq!(encoded[0], b'Q');
    }

    #[test]
    fn test_read_cstring_unterminated() {
        let mut buf = BytesMut::from("not-null-terminated");
        let err = read_cstring(&mut buf).expect_err("should reject unterminated cstring");
        assert!(
            matches!(err, ProxyError::Protocol(_)),
            "expected Protocol error, got {err:?}"
        );
    }

    #[test]
    fn test_read_cstring_sequence() {
        let mut buf = BytesMut::from(&b"first\0second\0tail"[..]);
        let first = read_cstring(&mut buf).unwrap();
        let second = read_cstring(&mut buf).unwrap();
        assert_eq!(first, "first");
        assert_eq!(second, "second");
        assert_eq!(&buf[..], b"tail");
    }

    #[test]
    fn test_bind_message_param_values_are_bytes() {
        let mut payload = BytesMut::new();
        payload.put_u8(0); // portal: ""
        payload.put_u8(0); // statement: ""
        payload.put_u16(1); // one parameter format code...
        payload.put_i16(0); // ...text
        payload.put_u16(2); // two parameter values
        payload.put_i32(2);
        payload.put_slice(b"hi");
        payload.put_i32(-1); // NULL
        payload.put_u16(0); // no result format codes

        let bind = BindMessage::parse(payload).expect("parse failed");
        assert_eq!(bind.param_values.len(), 2);
        assert_eq!(bind.param_values[0].as_deref(), Some(&b"hi"[..]));
        assert!(bind.param_values[1].is_none());
    }
}