feophantlib/codec/
pg_codec.rs1use bytes::{Buf, Bytes, BytesMut};
4use std::convert::TryFrom;
5use tokio_util::codec::{Decoder, Encoder};
6
7use super::NetworkFrame;
8
/// Codec that frames PostgreSQL wire-protocol messages for tokio.
///
/// The struct carries no state; all framing logic lives in its `Decoder` and
/// `Encoder` implementations.
#[derive(Debug, Default, Clone, Copy)]
pub struct PgCodec {}
10
11impl Decoder for PgCodec {
12 type Item = NetworkFrame;
13 type Error = std::io::Error;
14
15 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
16 if src.len() < 5 {
17 return Ok(None);
19 }
20
21 debug!("Got message {:?}", src);
22
23 let mut message_bytes = [0u8; 1];
25 message_bytes.copy_from_slice(&src[..1]);
26 let message_type = u8::from_be(message_bytes[0]);
27
28 let prefix_len;
30 if message_type == 0 {
31 prefix_len = 4;
32 } else {
33 prefix_len = 5;
34 }
35 let mut length_bytes = [0u8; 4];
36 length_bytes.copy_from_slice(&src[(prefix_len - 4)..prefix_len]);
37
38 let length = u32::from_be_bytes(length_bytes) as u32;
39 if length < 4 {
40 return Err(std::io::Error::new(
41 std::io::ErrorKind::InvalidData,
42 format!("Frame length of {} is too small", length),
43 ));
44 }
45
46 let length_size = u32::from_be_bytes(length_bytes) as usize - 4;
47
48 if src.len() < prefix_len + length_size {
60 src.reserve(prefix_len + length_size - src.len());
65
66 return Ok(None);
69 }
70
71 let data = src[prefix_len..prefix_len + length_size].to_vec();
74 src.advance(prefix_len + length_size);
75
76 debug!("Got message type {:x} and payload {:?}", message_type, data);
77
78 Ok(Some(NetworkFrame::new(message_type, Bytes::from(data))))
80 }
81}
82
83impl Encoder<NetworkFrame> for PgCodec {
84 type Error = std::io::Error;
85
86 fn encode(&mut self, item: NetworkFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
87 debug!(
88 "Sending message type {:x} and payload {:?}",
89 item.message_type, item.payload
90 );
91
92 if item.message_type == 0 {
94 dst.reserve(item.payload.len());
96 } else {
97 dst.reserve(5 + item.payload.len());
99
100 dst.extend_from_slice(&[item.message_type][..]);
102
103 let length = match u32::try_from(item.payload.len() + 4) {
105 Ok(n) => n,
106 Err(_) => {
107 return Err(std::io::Error::new(
108 std::io::ErrorKind::InvalidData,
109 format!(
110 "Frame of length {} plus length header is too large.",
111 item.payload.len()
112 ),
113 ))
114 }
115 };
116
117 let len_slice = u32::to_be_bytes(length);
118 dst.extend_from_slice(&len_slice);
119 }
120 dst.extend_from_slice(&item.payload);
122
123 Ok(())
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::super::super::processor::ssl_and_gssapi_parser;
    use super::*;
    use hex_literal::hex;

    /// An SSLRequest (untyped frame, length 8, code 0x04D2162F) decodes as a
    /// type-0 frame and consumes the whole buffer.
    #[test]
    fn test_decode() {
        let input = hex!("00 00 00 08 04 D2 16 2F");
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&input);

        let mut codec = PgCodec {};
        let msg = codec.decode(&mut buf).unwrap().unwrap();

        assert_eq!(msg.message_type, 0);
        assert_eq!(ssl_and_gssapi_parser::is_ssl_request(&msg.payload), true);
        assert!(buf.is_empty());
    }

    /// A typed frame that arrives in pieces yields nothing (and consumes
    /// nothing) until it is complete.
    #[test]
    fn test_decode_partial_typed_frame() {
        // 'Q', length 13 (4 + "SELECT 1\0"), then the start of the payload.
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&hex!("51 00 00 00 0D"));
        buf.extend_from_slice(b"SELE");

        let mut codec = PgCodec {};
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 9);

        buf.extend_from_slice(b"CT 1\0");
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(msg.message_type, b'Q');
        assert_eq!(&msg.payload[..], &b"SELECT 1\0"[..]);
        assert!(buf.is_empty());
    }

    /// A declared length below 4 cannot even cover itself and is rejected.
    #[test]
    fn test_decode_rejects_short_length() {
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&hex!("51 00 00 00 02"));

        let mut codec = PgCodec {};
        let err = codec
            .decode(&mut buf)
            .err()
            .expect("a length below 4 must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    /// Typed frames are written as type byte, self-inclusive length, payload.
    #[test]
    fn test_encode_typed_frame() {
        let mut codec = PgCodec {};
        let mut out = BytesMut::new();
        codec
            .encode(NetworkFrame::new(b'Z', Bytes::from_static(b"I")), &mut out)
            .unwrap();
        assert_eq!(&out[..], &hex!("5A 00 00 00 05 49")[..]);
    }

    /// Type-0 frames are written as the bare payload with no header.
    #[test]
    fn test_encode_untyped_frame_writes_raw_payload() {
        let mut codec = PgCodec {};
        let mut out = BytesMut::new();
        codec
            .encode(NetworkFrame::new(0, Bytes::from_static(b"N")), &mut out)
            .unwrap();
        assert_eq!(&out[..], &b"N"[..]);
    }

    /// Whatever the encoder writes for a typed frame, the decoder reads back.
    #[test]
    fn test_encode_decode_round_trip() {
        let mut codec = PgCodec {};
        let mut buf = BytesMut::new();
        codec
            .encode(
                NetworkFrame::new(b'Q', Bytes::from_static(b"SELECT 1\0")),
                &mut buf,
            )
            .unwrap();

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(msg.message_type, b'Q');
        assert_eq!(&msg.payload[..], &b"SELECT 1\0"[..]);
        assert!(buf.is_empty());
    }
}