feophantlib/codec/
pg_codec.rs

1//! Implementation hints from here: https://docs.rs/tokio-util/0.6.6/tokio_util/codec/index.html
2
3use bytes::{Buf, Bytes, BytesMut};
4use std::convert::TryFrom;
5use tokio_util::codec::{Decoder, Encoder};
6
7use super::NetworkFrame;
8
/// Stateless codec that frames a byte stream into [`NetworkFrame`]s and back.
///
/// A frame on the wire is an optional one-byte message type followed by a
/// big-endian 32-bit length (which counts itself but not the type byte) and
/// then the payload. The codec keeps no state between calls, so a single
/// value can be created with `PgCodec {}` or `PgCodec::default()`.
#[derive(Debug, Default)]
pub struct PgCodec {}
10
11impl Decoder for PgCodec {
12    type Item = NetworkFrame;
13    type Error = std::io::Error;
14
15    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
16        if src.len() < 5 {
17            // Not enough data to make a decision.
18            return Ok(None);
19        }
20
21        debug!("Got message {:?}", src);
22
23        //Read the first byte
24        let mut message_bytes = [0u8; 1];
25        message_bytes.copy_from_slice(&src[..1]);
26        let message_type = u8::from_be(message_bytes[0]);
27
28        //If the message_type is 0, then it doesn't have a type and should just be seen as the length
29        let prefix_len;
30        if message_type == 0 {
31            prefix_len = 4;
32        } else {
33            prefix_len = 5;
34        }
35        let mut length_bytes = [0u8; 4];
36        length_bytes.copy_from_slice(&src[(prefix_len - 4)..prefix_len]);
37
38        let length = u32::from_be_bytes(length_bytes) as u32;
39        if length < 4 {
40            return Err(std::io::Error::new(
41                std::io::ErrorKind::InvalidData,
42                format!("Frame length of {} is too small", length),
43            ));
44        }
45
46        let length_size = u32::from_be_bytes(length_bytes) as usize - 4;
47
48        // TODO - Unsure how to stop DDOS when the protocol allows up to 2GB of data
49        //          Would be great to know if the user is authenticated
50        // Check that the length is not too large to avoid a denial of
51        // service attack where the server runs out of memory.
52        //if length > MAX {
53        //    return Err(std::io::Error::new(
54        //        std::io::ErrorKind::InvalidData,
55        //        format!("Frame of length {} is too large.", length)
56        //    ));
57        //}
58
59        if src.len() < prefix_len + length_size {
60            // The full payload has not yet arrived.
61            //
62            // We reserve more space in the buffer. This is not strictly
63            // necessary, but is a good idea performance-wise.
64            src.reserve(prefix_len + length_size - src.len());
65
66            // We inform the Framed that we need more bytes to form the next
67            // frame.
68            return Ok(None);
69        }
70
71        // Use advance to modify src such that it no longer contains
72        // this frame.
73        let data = src[prefix_len..prefix_len + length_size].to_vec();
74        src.advance(prefix_len + length_size);
75
76        debug!("Got message type {:x} and payload {:?}", message_type, data);
77
78        // Convert the data to a string, or fail if it is not valid utf-8.
79        Ok(Some(NetworkFrame::new(message_type, Bytes::from(data))))
80    }
81}
82
83impl Encoder<NetworkFrame> for PgCodec {
84    type Error = std::io::Error;
85
86    fn encode(&mut self, item: NetworkFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
87        debug!(
88            "Sending message type {:x} and payload {:?}",
89            item.message_type, item.payload
90        );
91
92        //Messages types of zero are special because they get written out raw. Probably should find a better way to do this
93        if item.message_type == 0 {
94            // Reserve space in the buffer.
95            dst.reserve(item.payload.len());
96        } else {
97            // Reserve space in the buffer.
98            dst.reserve(5 + item.payload.len());
99
100            //Enter the type
101            dst.extend_from_slice(&[item.message_type][..]);
102
103            // Convert the length into a byte array.
104            let length = match u32::try_from(item.payload.len() + 4) {
105                Ok(n) => n,
106                Err(_) => {
107                    return Err(std::io::Error::new(
108                        std::io::ErrorKind::InvalidData,
109                        format!(
110                            "Frame of length {} plus length header is too large.",
111                            item.payload.len()
112                        ),
113                    ))
114                }
115            };
116
117            let len_slice = u32::to_be_bytes(length);
118            dst.extend_from_slice(&len_slice);
119        }
120        //Write to Buffer
121        dst.extend_from_slice(&item.payload);
122
123        Ok(())
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::super::super::processor::ssl_and_gssapi_parser;
    use super::*;
    use hex_literal::hex;

    /// Decodes an eight-byte untyped packet (leading byte `00`, declared
    /// length 8) and checks that it comes back as a type-0 frame whose
    /// payload `is_ssl_request` accepts.
    #[test]
    fn test_decode() {
        let raw_packet = hex!("00 00 00 08 04 D2 16 2F");
        let mut buffer = BytesMut::from(&raw_packet[..]);

        let mut codec = PgCodec {};
        let frame = codec
            .decode(&mut buffer)
            .unwrap()
            .unwrap();

        assert_eq!(frame.message_type, 0);
        assert!(ssl_and_gssapi_parser::is_ssl_request(&frame.payload));
    }
}
145}