fern_protocol_postgresql/codec/
mod.rs

1// SPDX-FileCopyrightText:  Copyright © 2022 The Fern Authors <team@fernproxy.io>
2// SPDX-License-Identifier: Apache-2.0
3
//! Adaptors from [`AsyncRead`]/[`AsyncWrite`] to [`Stream`]/[`Sink`],
//! decoding/encoding PostgreSQL protocol version 3.
//!
//! [`AsyncRead`]: https://docs.rs/tokio/*/tokio/io/trait.AsyncRead.html
//! [`AsyncWrite`]: https://docs.rs/tokio/*/tokio/io/trait.AsyncWrite.html
//! [`Stream`]: https://docs.rs/futures/*/futures/stream/trait.Stream.html
//! [`Sink`]: https://docs.rs/futures/*/futures/sink/trait.Sink.html
11
12use fern_proxy_interfaces::SQLMessage;
13
// Submodule declarations; their contents live in separate files.
// NOTE(review): going by the names, presumably the codecs for messages sent
// by the server (`backend`) and by the client (`frontend`) — confirm.
pub mod backend;
pub mod frontend;
16
17/// A trait to abstract frontend and backend Messages.
18//TODO(ppiotr3k): make private again
19pub trait PostgresMessage: SQLMessage + std::fmt::Debug {}
20
/// Collection of constants used internally.
pub(crate) mod constants {
    /// Amount of bytes for the PostgreSQL Message identifier.
    pub const BYTES_MESSAGE_ID: usize = 1;

    /// Amount of bytes for the PostgreSQL Message payload size, a length
    /// value which includes these size bytes themselves.
    pub const BYTES_MESSAGE_SIZE: usize = 4;

    /// Amount of bytes for the PostgreSQL Message header: the identifier
    /// followed by the payload size.
    pub const BYTES_MESSAGE_HEADER: usize = BYTES_MESSAGE_ID + BYTES_MESSAGE_SIZE;
}
32
33/// Collection of bytes manipulating helper functions used internally.
34pub(crate) mod utils {
35
36    use bytes::{Buf, BufMut, Bytes, BytesMut};
37    use std::io;
38
39    /// Creates a new [`Bytes`] instance by first reading a length header of
40    /// `length_bytes`, and then getting as many bytes.
41    ///
42    /// The current position in `buf` is advanced by `length_bytes` and
43    /// the value contained in the size header.
44    ///
45    /// This function is optimized to avoid copies by using a [`Bytes`]
46    /// implementation which only performs a shallow copy.
47    ///
48    /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
49    ///
50    /// [`Bytes`]: https://docs.rs/bytes/*/bytes/struct.Bytes.html
51    pub(crate) fn get_bytes(
52        buf: &mut BytesMut,
53        length_bytes: usize,
54        error_msg: &str,
55    ) -> io::Result<Bytes> {
56        // Shouldn't happend, unless packet is malformed.
57        if buf.remaining() < length_bytes {
58            let err = io::Error::new(
59                io::ErrorKind::UnexpectedEof,
60                "malformed packet - invalid data size",
61            );
62            log::error!("{}", err);
63            return Err(err);
64        }
65
66        let data_length = buf.get_u32();
67        log::trace!("bytes data length header: {}", data_length);
68
69        let data = if data_length == u32::MAX {
70            Bytes::new()
71        } else {
72            if buf.remaining() < data_length as usize {
73                log::error!("{}", error_msg);
74                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
75            }
76            buf.copy_to_bytes(data_length as usize)
77        };
78
79        Ok(data)
80    }
81
82    /// Writes bytes to `buf`, prefixing `data` with a size header.
83    ///
84    /// The current position in `buf` is advanced by the length of `data`,
85    /// and the 4 bytes required to define the size header.
86    ///
87    /// # Panics
88    ///
89    /// This function panics if there is not enough remaining capacity in `buf`.
90    pub(crate) fn put_bytes(data: &Bytes, buf: &mut BytesMut) {
91        buf.put_u32(data.len() as u32);
92        // Cloning `Bytes` is cheap, it is an `Arc` increment.
93        buf.put((*data).clone());
94    }
95
96    /// Gets a C-style null character terminated string from `buf`
97    /// as a new [`Bytes`] instance.
98    ///
99    /// The current position in `buf` is advanced by the length of the string,
100    /// and 1 byte to account for the null character terminator.
101    ///
102    /// This function is optimized to avoid copies by using a [`Bytes`]
103    /// implementation which only performs a shallow copy.
104    ///
105    /// Returns [`io::ErrorKind::InvalidData`] if no null chararacter terminator if found.
106    ///
107    /// [`Bytes`]: https://docs.rs/bytes/*/bytes/struct.Bytes.html
108    pub(crate) fn get_cstr(buf: &mut BytesMut) -> io::Result<Bytes> {
109        let nullchar_offset = buf[..].iter().position(|x| *x == b'\0');
110
111        match nullchar_offset {
112            None => {
113                let err = io::Error::new(
114                    io::ErrorKind::InvalidData,
115                    "malformed packet - cstr without null char terminator",
116                );
117                log::error!("{}", err);
118                Err(err)
119            }
120            Some(offset) => {
121                let str_bytes = buf.copy_to_bytes(offset);
122                buf.advance(1); // consume null char
123                Ok(str_bytes)
124            }
125        }
126    }
127
128    /// Writes a C-style null character terminated string to `buf`.
129    ///
130    /// The current position in `buf` is advanced by the length of the string,
131    /// and 1 byte to account for the null character terminator.
132    ///
133    /// # Note
134    ///
135    /// The [`Bytes`] instance defining the string is deemed to contain
136    /// only valid ASCII characters, as it is deemed to originate from a
137    /// controlled and/or sanitized context.
138    ///
139    /// # Panics
140    ///
141    /// This function panics if there is not enough remaining capacity in `buf`.
142    ///
143    /// [`Bytes`]: https://docs.rs/bytes/*/bytes/struct.Bytes.html
144    pub(crate) fn put_cstr(string: &Bytes, buf: &mut BytesMut) {
145        // Cloning `Bytes` is cheap, it is an `Arc` increment.
146        buf.put((*string).clone());
147        buf.put_u8(b'\0');
148    }
149
150    /// Gets an unsigned 32-bit integer from `buf` in big-endian byte order,
151    /// and advances current position by 4.
152    ///
153    /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
154    pub(crate) fn get_u32(buf: &mut BytesMut, error_msg: &str) -> io::Result<u32> {
155        if buf.remaining() < 4 {
156            log::error!("{}", error_msg);
157            return Err(io::Error::new(io::ErrorKind::InvalidData, error_msg));
158        }
159
160        let value = buf.get_u32();
161        Ok(value)
162    }
163
164    /// Gets a signed 32-bit integer from `buf` in big-endian byte order,
165    /// and advances current position by 4.
166    ///
167    /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
168    pub(crate) fn get_i32(buf: &mut BytesMut, error_msg: &str) -> io::Result<i32> {
169        if buf.remaining() < 4 {
170            log::error!("{}", error_msg);
171            return Err(io::Error::new(io::ErrorKind::InvalidData, error_msg));
172        }
173
174        let value = buf.get_i32();
175        Ok(value)
176    }
177
178    /// Gets an unsigned 16-bit integer from `buf` in big-endian byte order,
179    /// and advances current position by 2.
180    ///
181    /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
182    pub(crate) fn get_u16(buf: &mut BytesMut, error_msg: &str) -> io::Result<u16> {
183        if buf.remaining() < 2 {
184            log::error!("{}", error_msg);
185            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
186        }
187
188        let value = buf.get_u16();
189        Ok(value)
190    }
191
192    /// Gets a signed 16-bit integer from `buf` in big-endian byte order,
193    /// and advances current position by 2.
194    ///
195    /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
196    pub(crate) fn get_i16(buf: &mut BytesMut, error_msg: &str) -> io::Result<i16> {
197        if buf.remaining() < 2 {
198            log::error!("{}", error_msg);
199            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
200        }
201
202        let value = buf.get_i16();
203        Ok(value)
204    }
205
206    /// Gets an unsigned 8-bit integer from `buf` in big-endian byte order,
207    /// and advances current position by 1.
208    ///
209    /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
210    pub(crate) fn get_u8(buf: &mut BytesMut, error_msg: &str) -> io::Result<u8> {
211        if buf.remaining() < 1 {
212            log::error!("{}", error_msg);
213            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
214        }
215
216        let value = buf.get_u8();
217        Ok(value)
218    }
219}