fern_protocol_postgresql/codec/mod.rs
1// SPDX-FileCopyrightText: Copyright © 2022 The Fern Authors <team@fernproxy.io>
2// SPDX-License-Identifier: Apache-2.0
3
4//! Adaptors from [`AsyncRead`]/[`AsyncWrite`] to [`Stream`]/[`Sink`],
5//! decoding/encoding PostgreSQL protocol version 3.
6//!
7//! [`AsyncRead`]: https://docs.rs/tokio/*/tokio/io/trait.AsyncRead.html
8//! [`AsyncWrite`]: https://docs.rs/tokio/latest/tokio/io/trait.AsyncWrite.html
9//! [`Stream`]: https://docs.rs/futures/*/futures/stream/trait.Stream.html
10//! [`Sink`]: https://docs.rs/futures/*/futures/sink/trait.Sink.html
11
12use fern_proxy_interfaces::SQLMessage;
13
14pub mod backend;
15pub mod frontend;
16
17/// A trait to abstract frontend and backend Messages.
18//TODO(ppiotr3k): make private again
19pub trait PostgresMessage: SQLMessage + std::fmt::Debug {}
20
/// Collection of constants used internally.
pub(crate) mod constants {
    /// Amount of bytes for a PostgreSQL Message identifier (a single `Byte1`).
    pub const BYTES_MESSAGE_ID: usize = std::mem::size_of::<u8>();

    /// Amount of bytes for a PostgreSQL Message payload size (an `Int32`).
    ///
    /// The size value carried in this field counts these 4 bytes as well.
    pub const BYTES_MESSAGE_SIZE: usize = std::mem::size_of::<i32>();

    /// Amount of bytes for a PostgreSQL Message header: identifier + payload size.
    pub const BYTES_MESSAGE_HEADER: usize = BYTES_MESSAGE_ID + BYTES_MESSAGE_SIZE;
}
32
33/// Collection of bytes manipulating helper functions used internally.
34pub(crate) mod utils {
35
36 use bytes::{Buf, BufMut, Bytes, BytesMut};
37 use std::io;
38
39 /// Creates a new [`Bytes`] instance by first reading a length header of
40 /// `length_bytes`, and then getting as many bytes.
41 ///
42 /// The current position in `buf` is advanced by `length_bytes` and
43 /// the value contained in the size header.
44 ///
45 /// This function is optimized to avoid copies by using a [`Bytes`]
46 /// implementation which only performs a shallow copy.
47 ///
48 /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
49 ///
50 /// [`Bytes`]: https://docs.rs/bytes/*/bytes/struct.Bytes.html
51 pub(crate) fn get_bytes(
52 buf: &mut BytesMut,
53 length_bytes: usize,
54 error_msg: &str,
55 ) -> io::Result<Bytes> {
56 // Shouldn't happend, unless packet is malformed.
57 if buf.remaining() < length_bytes {
58 let err = io::Error::new(
59 io::ErrorKind::UnexpectedEof,
60 "malformed packet - invalid data size",
61 );
62 log::error!("{}", err);
63 return Err(err);
64 }
65
66 let data_length = buf.get_u32();
67 log::trace!("bytes data length header: {}", data_length);
68
69 let data = if data_length == u32::MAX {
70 Bytes::new()
71 } else {
72 if buf.remaining() < data_length as usize {
73 log::error!("{}", error_msg);
74 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
75 }
76 buf.copy_to_bytes(data_length as usize)
77 };
78
79 Ok(data)
80 }
81
82 /// Writes bytes to `buf`, prefixing `data` with a size header.
83 ///
84 /// The current position in `buf` is advanced by the length of `data`,
85 /// and the 4 bytes required to define the size header.
86 ///
87 /// # Panics
88 ///
89 /// This function panics if there is not enough remaining capacity in `buf`.
90 pub(crate) fn put_bytes(data: &Bytes, buf: &mut BytesMut) {
91 buf.put_u32(data.len() as u32);
92 // Cloning `Bytes` is cheap, it is an `Arc` increment.
93 buf.put((*data).clone());
94 }
95
96 /// Gets a C-style null character terminated string from `buf`
97 /// as a new [`Bytes`] instance.
98 ///
99 /// The current position in `buf` is advanced by the length of the string,
100 /// and 1 byte to account for the null character terminator.
101 ///
102 /// This function is optimized to avoid copies by using a [`Bytes`]
103 /// implementation which only performs a shallow copy.
104 ///
105 /// Returns [`io::ErrorKind::InvalidData`] if no null chararacter terminator if found.
106 ///
107 /// [`Bytes`]: https://docs.rs/bytes/*/bytes/struct.Bytes.html
108 pub(crate) fn get_cstr(buf: &mut BytesMut) -> io::Result<Bytes> {
109 let nullchar_offset = buf[..].iter().position(|x| *x == b'\0');
110
111 match nullchar_offset {
112 None => {
113 let err = io::Error::new(
114 io::ErrorKind::InvalidData,
115 "malformed packet - cstr without null char terminator",
116 );
117 log::error!("{}", err);
118 Err(err)
119 }
120 Some(offset) => {
121 let str_bytes = buf.copy_to_bytes(offset);
122 buf.advance(1); // consume null char
123 Ok(str_bytes)
124 }
125 }
126 }
127
128 /// Writes a C-style null character terminated string to `buf`.
129 ///
130 /// The current position in `buf` is advanced by the length of the string,
131 /// and 1 byte to account for the null character terminator.
132 ///
133 /// # Note
134 ///
135 /// The [`Bytes`] instance defining the string is deemed to contain
136 /// only valid ASCII characters, as it is deemed to originate from a
137 /// controlled and/or sanitized context.
138 ///
139 /// # Panics
140 ///
141 /// This function panics if there is not enough remaining capacity in `buf`.
142 ///
143 /// [`Bytes`]: https://docs.rs/bytes/*/bytes/struct.Bytes.html
144 pub(crate) fn put_cstr(string: &Bytes, buf: &mut BytesMut) {
145 // Cloning `Bytes` is cheap, it is an `Arc` increment.
146 buf.put((*string).clone());
147 buf.put_u8(b'\0');
148 }
149
150 /// Gets an unsigned 32-bit integer from `buf` in big-endian byte order,
151 /// and advances current position by 4.
152 ///
153 /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
154 pub(crate) fn get_u32(buf: &mut BytesMut, error_msg: &str) -> io::Result<u32> {
155 if buf.remaining() < 4 {
156 log::error!("{}", error_msg);
157 return Err(io::Error::new(io::ErrorKind::InvalidData, error_msg));
158 }
159
160 let value = buf.get_u32();
161 Ok(value)
162 }
163
164 /// Gets a signed 32-bit integer from `buf` in big-endian byte order,
165 /// and advances current position by 4.
166 ///
167 /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
168 pub(crate) fn get_i32(buf: &mut BytesMut, error_msg: &str) -> io::Result<i32> {
169 if buf.remaining() < 4 {
170 log::error!("{}", error_msg);
171 return Err(io::Error::new(io::ErrorKind::InvalidData, error_msg));
172 }
173
174 let value = buf.get_i32();
175 Ok(value)
176 }
177
178 /// Gets an unsigned 16-bit integer from `buf` in big-endian byte order,
179 /// and advances current position by 2.
180 ///
181 /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
182 pub(crate) fn get_u16(buf: &mut BytesMut, error_msg: &str) -> io::Result<u16> {
183 if buf.remaining() < 2 {
184 log::error!("{}", error_msg);
185 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
186 }
187
188 let value = buf.get_u16();
189 Ok(value)
190 }
191
192 /// Gets a signed 16-bit integer from `buf` in big-endian byte order,
193 /// and advances current position by 2.
194 ///
195 /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
196 pub(crate) fn get_i16(buf: &mut BytesMut, error_msg: &str) -> io::Result<i16> {
197 if buf.remaining() < 2 {
198 log::error!("{}", error_msg);
199 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
200 }
201
202 let value = buf.get_i16();
203 Ok(value)
204 }
205
206 /// Gets an unsigned 8-bit integer from `buf` in big-endian byte order,
207 /// and advances current position by 1.
208 ///
209 /// Returns [`io::ErrorKind::UnexpectedEof`] with `error_msg` if there is not enough data.
210 pub(crate) fn get_u8(buf: &mut BytesMut, error_msg: &str) -> io::Result<u8> {
211 if buf.remaining() < 1 {
212 log::error!("{}", error_msg);
213 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, error_msg));
214 }
215
216 let value = buf.get_u8();
217 Ok(value)
218 }
219}