1use async_trait::async_trait;
4use bytes::{BufMut, BytesMut};
5use std::{
6 convert::TryFrom,
7 fmt::Debug,
8 io::{Cursor, Error, ErrorKind},
9 marker::Send,
10 sync::Arc,
11};
12
13use crate::{
14 protocol::{ErrorCode, ErrorResponse},
15 ProtocolError,
16};
17use log::trace;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19
20use super::protocol::{self, Deserialize, FrontendMessage, Serialize};
21
/// Decodes the payload of a single frontend message once its one-byte tag is known.
///
/// Implementors must be `Sync + Send + Debug`; this allows them to be handed around as
/// the `Arc<dyn MessageTagParser>` that [`read_message`] accepts.
#[async_trait]
pub trait MessageTagParser: Sync + Send + Debug {
    /// Builds a [`FrontendMessage`] for `tag` from its payload.
    ///
    /// When called from [`read_message`], `cursor` wraps exactly the bytes that follow
    /// the tag byte and the 4-byte length prefix (see [`read_contents`]).
    ///
    /// # Errors
    /// Returns a [`ProtocolError`] when the message cannot be produced, e.g. an
    /// unrecognised tag or a payload that fails to deserialize (exact conditions are
    /// up to the implementor).
    async fn parse(
        &self,
        tag: u8,
        cursor: Cursor<Vec<u8>>,
    ) -> Result<FrontendMessage, ProtocolError>;
}
30
31#[derive(Default, Debug)]
32pub struct MessageTagParserDefaultImpl {}
33
34impl MessageTagParserDefaultImpl {
35 pub fn new() -> Self {
36 Self {}
37 }
38
39 pub fn with_arc() -> Arc<dyn MessageTagParser> {
40 Arc::new(Self::new())
41 }
42}
43
44#[async_trait]
45impl MessageTagParser for MessageTagParserDefaultImpl {
46 async fn parse(
47 &self,
48 tag: u8,
49 cursor: Cursor<Vec<u8>>,
50 ) -> Result<FrontendMessage, ProtocolError> {
51 let message = match tag {
52 b'Q' => FrontendMessage::Query(protocol::Query::deserialize(cursor).await?),
53 b'P' => FrontendMessage::Parse(protocol::Parse::deserialize(cursor).await?),
54 b'B' => FrontendMessage::Bind(protocol::Bind::deserialize(cursor).await?),
55 b'D' => FrontendMessage::Describe(protocol::Describe::deserialize(cursor).await?),
56 b'E' => FrontendMessage::Execute(protocol::Execute::deserialize(cursor).await?),
57 b'C' => FrontendMessage::Close(protocol::Close::deserialize(cursor).await?),
58 b'p' => FrontendMessage::PasswordMessage(
59 protocol::PasswordMessage::deserialize(cursor).await?,
60 ),
61 b'X' => FrontendMessage::Terminate,
62 b'H' => FrontendMessage::Flush,
63 b'S' => FrontendMessage::Sync,
64 identifier => {
65 return Err(ErrorResponse::error(
66 ErrorCode::DataException,
67 format!("Unknown message identifier: {:X?}", identifier),
68 )
69 .into())
70 }
71 };
72 Ok(message)
73 }
74}
75
76pub async fn read_message<Reader: AsyncReadExt + Unpin + Send>(
77 reader: &mut Reader,
78 parser: Arc<dyn MessageTagParser>,
79) -> Result<FrontendMessage, ProtocolError> {
80 let message_tag = reader.read_u8().await?;
82 let cursor = read_contents(reader, message_tag).await?;
83 let message = parser.parse(message_tag, cursor).await?;
84
85 trace!("[pg] Decoded {:X?}", message,);
86
87 Ok(message)
88}
89
90pub async fn read_contents<Reader: AsyncReadExt + Unpin>(
91 reader: &mut Reader,
92 message_tag: u8,
93) -> Result<Cursor<Vec<u8>>, Error> {
94 let length = reader.read_u32().await?;
96 if length < 4 {
97 return Err(Error::other("Unexpectedly small (<0) message size"));
98 }
99
100 trace!(
101 "[pg] Receive package {:X?} with length {}",
102 message_tag,
103 length
104 );
105
106 let length = usize::try_from(length - 4).map_err(|_| {
107 Error::new(
108 ErrorKind::OutOfMemory,
109 "Unable to convert message length to a suitable memory size",
110 )
111 })?;
112
113 let buffer = if length == 0 {
114 vec![0; 0]
115 } else {
116 let mut buffer = vec![0; length];
117 reader.read_exact(&mut buffer).await?;
118
119 buffer
120 };
121
122 let cursor = Cursor::new(buffer);
123
124 Ok(cursor)
125}
126
127pub async fn read_string<Reader: AsyncReadExt + Unpin>(
128 reader: &mut Reader,
129) -> Result<String, Error> {
130 let mut bytes = Vec::with_capacity(64);
131
132 loop {
133 let byte = reader.read_u8().await?;
135 if byte == 0 {
136 break;
137 }
138
139 bytes.push(byte);
140 }
141
142 let string = String::from_utf8(bytes).map_err(|_| {
143 Error::new(
144 ErrorKind::InvalidData,
145 "Unable to parse bytes as a UTF-8 string",
146 )
147 })?;
148
149 Ok(string)
150}
151
152pub async fn read_format<Reader: AsyncReadExt + Unpin>(
153 reader: &mut Reader,
154) -> Result<protocol::Format, ProtocolError> {
155 match reader.read_i16().await? {
156 0 => Ok(protocol::Format::Text),
157 1 => Ok(protocol::Format::Binary),
158 format_code => Err(protocol::ErrorResponse::error(
159 protocol::ErrorCode::ProtocolViolation,
160 format!("Unknown format code: {}", format_code),
161 )
162 .into()),
163 }
164}
165
166pub async fn write_direct<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
168 partial_write: &mut BytesMut,
169 writer: &mut Writer,
170 message: Message,
171) -> Result<(), ProtocolError> {
172 let mut bytes_mut = BytesMut::new();
173 if let Some(buffer) = message.serialize() {
174 bytes_mut.extend_from_slice(&buffer);
176 *partial_write = bytes_mut;
177 writer.write_all_buf(partial_write).await?;
178 *partial_write = BytesMut::new();
179 writer.flush().await?;
180 }
181
182 Ok(())
183}
184
185fn message_serialize<Message: Serialize>(
186 message: Message,
187 packet_buffer: &mut BytesMut,
188) -> Result<(), ProtocolError> {
189 if message.code() != 0x00 {
190 packet_buffer.put_u8(message.code());
191 }
192
193 if let Some(buffer) = message.serialize() {
194 let size = u32::try_from(buffer.len() + 4).map_err(|_| {
195 ErrorResponse::error(
196 ErrorCode::InternalError,
197 "Unable to convert buffer length to a suitable memory size".to_string(),
198 )
199 })?;
200 packet_buffer.extend_from_slice(&size.to_be_bytes());
201 packet_buffer.extend_from_slice(&buffer);
202 }
203
204 Ok(())
205}
206
207pub async fn write_messages<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
211 partial_write: &mut BytesMut,
212 writer: &mut Writer,
213 messages: Vec<Message>,
214) -> Result<(), ProtocolError> {
215 let mut buffer = BytesMut::with_capacity(64 * messages.len());
216
217 for message in messages {
218 message_serialize(message, &mut buffer)?;
219 }
220
221 *partial_write = buffer;
225 writer.write_all_buf(partial_write).await?;
226 *partial_write = BytesMut::new();
227
228 writer.flush().await?;
232 Ok(())
233}
234
235pub async fn write_message<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
239 partial_write: &mut BytesMut,
240 writer: &mut Writer,
241 message: Message,
242) -> Result<(), ProtocolError> {
243 let mut buffer = BytesMut::with_capacity(64);
244 message_serialize(message, &mut buffer)?;
245
246 *partial_write = buffer;
247 writer.write_all_buf(partial_write).await?;
248 *partial_write = BytesMut::new();
249 writer.flush().await?;
250 Ok(())
251}
252
/// Appends `string` to `buffer` as a NUL-terminated byte string.
///
/// The UTF-8 bytes of `string` are copied verbatim and followed by a single `0` byte.
/// Interior NUL bytes are not checked for.
pub fn write_string(buffer: &mut Vec<u8>, string: &str) {
    buffer.extend(string.bytes().chain(std::iter::once(0u8)));
}