Skip to main content

pg_srv/
buffer.rs

1//! Helpers for reading/writing from/to the connection's socket
2
3use async_trait::async_trait;
4use bytes::{BufMut, BytesMut};
5use std::{
6    convert::TryFrom,
7    fmt::Debug,
8    io::{Cursor, Error, ErrorKind},
9    marker::Send,
10    sync::Arc,
11};
12
13use crate::{
14    protocol::{ErrorCode, ErrorResponse},
15    ProtocolError,
16};
17use log::trace;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19
20use super::protocol::{self, Deserialize, FrontendMessage, Serialize};
21
22#[async_trait]
23pub trait MessageTagParser: Sync + Send + Debug {
24    async fn parse(
25        &self,
26        tag: u8,
27        cursor: Cursor<Vec<u8>>,
28    ) -> Result<FrontendMessage, ProtocolError>;
29}
30
31#[derive(Default, Debug)]
32pub struct MessageTagParserDefaultImpl {}
33
34impl MessageTagParserDefaultImpl {
35    pub fn new() -> Self {
36        Self {}
37    }
38
39    pub fn with_arc() -> Arc<dyn MessageTagParser> {
40        Arc::new(Self::new())
41    }
42}
43
44#[async_trait]
45impl MessageTagParser for MessageTagParserDefaultImpl {
46    async fn parse(
47        &self,
48        tag: u8,
49        cursor: Cursor<Vec<u8>>,
50    ) -> Result<FrontendMessage, ProtocolError> {
51        let message = match tag {
52            b'Q' => FrontendMessage::Query(protocol::Query::deserialize(cursor).await?),
53            b'P' => FrontendMessage::Parse(protocol::Parse::deserialize(cursor).await?),
54            b'B' => FrontendMessage::Bind(protocol::Bind::deserialize(cursor).await?),
55            b'D' => FrontendMessage::Describe(protocol::Describe::deserialize(cursor).await?),
56            b'E' => FrontendMessage::Execute(protocol::Execute::deserialize(cursor).await?),
57            b'C' => FrontendMessage::Close(protocol::Close::deserialize(cursor).await?),
58            b'p' => FrontendMessage::PasswordMessage(
59                protocol::PasswordMessage::deserialize(cursor).await?,
60            ),
61            b'X' => FrontendMessage::Terminate,
62            b'H' => FrontendMessage::Flush,
63            b'S' => FrontendMessage::Sync,
64            identifier => {
65                return Err(ErrorResponse::error(
66                    ErrorCode::DataException,
67                    format!("Unknown message identifier: {:X?}", identifier),
68                )
69                .into())
70            }
71        };
72        Ok(message)
73    }
74}
75
76pub async fn read_message<Reader: AsyncReadExt + Unpin + Send>(
77    reader: &mut Reader,
78    parser: Arc<dyn MessageTagParser>,
79) -> Result<FrontendMessage, ProtocolError> {
80    // https://www.postgresql.org/docs/14/protocol-message-formats.html
81    let message_tag = reader.read_u8().await?;
82    let cursor = read_contents(reader, message_tag).await?;
83    let message = parser.parse(message_tag, cursor).await?;
84
85    trace!("[pg] Decoded {:X?}", message,);
86
87    Ok(message)
88}
89
90pub async fn read_contents<Reader: AsyncReadExt + Unpin>(
91    reader: &mut Reader,
92    message_tag: u8,
93) -> Result<Cursor<Vec<u8>>, Error> {
94    // protocol defines length for all types of messages
95    let length = reader.read_u32().await?;
96    if length < 4 {
97        return Err(Error::other("Unexpectedly small (<0) message size"));
98    }
99
100    trace!(
101        "[pg] Receive package {:X?} with length {}",
102        message_tag,
103        length
104    );
105
106    let length = usize::try_from(length - 4).map_err(|_| {
107        Error::new(
108            ErrorKind::OutOfMemory,
109            "Unable to convert message length to a suitable memory size",
110        )
111    })?;
112
113    let buffer = if length == 0 {
114        vec![0; 0]
115    } else {
116        let mut buffer = vec![0; length];
117        reader.read_exact(&mut buffer).await?;
118
119        buffer
120    };
121
122    let cursor = Cursor::new(buffer);
123
124    Ok(cursor)
125}
126
127pub async fn read_string<Reader: AsyncReadExt + Unpin>(
128    reader: &mut Reader,
129) -> Result<String, Error> {
130    let mut bytes = Vec::with_capacity(64);
131
132    loop {
133        // PostgreSQL uses a null-terminated string (C-style string)
134        let byte = reader.read_u8().await?;
135        if byte == 0 {
136            break;
137        }
138
139        bytes.push(byte);
140    }
141
142    let string = String::from_utf8(bytes).map_err(|_| {
143        Error::new(
144            ErrorKind::InvalidData,
145            "Unable to parse bytes as a UTF-8 string",
146        )
147    })?;
148
149    Ok(string)
150}
151
152pub async fn read_format<Reader: AsyncReadExt + Unpin>(
153    reader: &mut Reader,
154) -> Result<protocol::Format, ProtocolError> {
155    match reader.read_i16().await? {
156        0 => Ok(protocol::Format::Text),
157        1 => Ok(protocol::Format::Binary),
158        format_code => Err(protocol::ErrorResponse::error(
159            protocol::ErrorCode::ProtocolViolation,
160            format!("Unknown format code: {}", format_code),
161        )
162        .into()),
163    }
164}
165
166/// Same as the write_message function, but it doesn’t append header for frame (code + size).
167pub async fn write_direct<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
168    partial_write: &mut BytesMut,
169    writer: &mut Writer,
170    message: Message,
171) -> Result<(), ProtocolError> {
172    let mut bytes_mut = BytesMut::new();
173    if let Some(buffer) = message.serialize() {
174        // TODO: Yet another memory copy.
175        bytes_mut.extend_from_slice(&buffer);
176        *partial_write = bytes_mut;
177        writer.write_all_buf(partial_write).await?;
178        *partial_write = BytesMut::new();
179        writer.flush().await?;
180    }
181
182    Ok(())
183}
184
185fn message_serialize<Message: Serialize>(
186    message: Message,
187    packet_buffer: &mut BytesMut,
188) -> Result<(), ProtocolError> {
189    if message.code() != 0x00 {
190        packet_buffer.put_u8(message.code());
191    }
192
193    if let Some(buffer) = message.serialize() {
194        let size = u32::try_from(buffer.len() + 4).map_err(|_| {
195            ErrorResponse::error(
196                ErrorCode::InternalError,
197                "Unable to convert buffer length to a suitable memory size".to_string(),
198            )
199        })?;
200        packet_buffer.extend_from_slice(&size.to_be_bytes());
201        packet_buffer.extend_from_slice(&buffer);
202    }
203
204    Ok(())
205}
206
207/// Write multiple F messages with frame's headers to the writer.  The variable
208/// `*partial_write` is set for graceful shutdown attempts with partial writes.
209/// Upon a successful write, it is left empty.
210pub async fn write_messages<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
211    partial_write: &mut BytesMut,
212    writer: &mut Writer,
213    messages: Vec<Message>,
214) -> Result<(), ProtocolError> {
215    let mut buffer = BytesMut::with_capacity(64 * messages.len());
216
217    for message in messages {
218        message_serialize(message, &mut buffer)?;
219    }
220
221    // For simplicity we obviously don't save message boundary data with
222    // `*partial_write`, which means that a AdminShutdown fatal error message
223    // would have to be written after _all_ these messages.
224    *partial_write = buffer;
225    writer.write_all_buf(partial_write).await?;
226    *partial_write = BytesMut::new();
227
228    // (We _could_ reuse the buffer in *partial_write, doing fewer allocations -- after
229    // making other serialization logic allocate less and thinking about memory usage.)
230
231    writer.flush().await?;
232    Ok(())
233}
234
235/// Write single F message with frame's headers to the writer.  As with the
236/// function `write_messages`, `*partial_write` is set for graceful shutdown
237/// attempts with partial writes.  Upon a successful write, it is left empty.
238pub async fn write_message<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
239    partial_write: &mut BytesMut,
240    writer: &mut Writer,
241    message: Message,
242) -> Result<(), ProtocolError> {
243    let mut buffer = BytesMut::with_capacity(64);
244    message_serialize(message, &mut buffer)?;
245
246    *partial_write = buffer;
247    writer.write_all_buf(partial_write).await?;
248    *partial_write = BytesMut::new();
249    writer.flush().await?;
250    Ok(())
251}
252
/// Appends `string` to `buffer` as a C-style string: its UTF-8 bytes followed
/// by a single NUL terminator.
pub fn write_string(buffer: &mut Vec<u8>, string: &str) {
    buffer.extend(string.bytes().chain(std::iter::once(b'\0')));
}