use bytes::{BufMut, BytesMut};
use std::{
convert::TryFrom,
io::{Cursor, Error, ErrorKind},
marker::Send,
};
use crate::{
protocol::{ErrorCode, ErrorResponse},
ProtocolError,
};
use log::trace;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::protocol::{self, Deserialize, FrontendMessage, Serialize};
/// Reads one frontend (client → server) message from `reader`.
///
/// Consumes the one-byte message tag, then the length-prefixed body via
/// [`read_contents`], and finally decodes the body according to the tag.
/// For `Terminate` (`X`), `Flush` (`H`) and `Sync` (`S`) the body is read and
/// discarded.
///
/// # Errors
/// Returns an error if reading from `reader` fails, if the body cannot be
/// decoded for its tag, or, as a `ProtocolViolation` error response, if the tag
/// is not one of the supported frontend message types.
pub async fn read_message<Reader: AsyncReadExt + Unpin + Send>(
    reader: &mut Reader,
) -> Result<FrontendMessage, ProtocolError> {
    let message_tag = reader.read_u8().await?;
    // The body is read before dispatching on the tag, so it is always consumed,
    // even when the tag turns out to be unsupported.
    let cursor = read_contents(reader, message_tag).await?;
    let message = match message_tag {
        b'Q' => FrontendMessage::Query(protocol::Query::deserialize(cursor).await?),
        b'P' => FrontendMessage::Parse(protocol::Parse::deserialize(cursor).await?),
        b'B' => FrontendMessage::Bind(protocol::Bind::deserialize(cursor).await?),
        b'D' => FrontendMessage::Describe(protocol::Describe::deserialize(cursor).await?),
        b'E' => FrontendMessage::Execute(protocol::Execute::deserialize(cursor).await?),
        b'C' => FrontendMessage::Close(protocol::Close::deserialize(cursor).await?),
        b'p' => {
            FrontendMessage::PasswordMessage(protocol::PasswordMessage::deserialize(cursor).await?)
        }
        b'X' => FrontendMessage::Terminate,
        b'H' => FrontendMessage::Flush,
        b'S' => FrontendMessage::Sync,
        identifier => {
            // An unrecognised message type is a wire-protocol error, not bad data.
            // This matches how PostgreSQL classifies an invalid frontend message type,
            // and how `read_format` reports an unknown format code.
            return Err(ErrorResponse::error(
                ErrorCode::ProtocolViolation,
                format!("Unknown message identifier: {:X?}", identifier),
            )
            .into())
        }
    };
    // NOTE(review): this logs every decoded message, including `PasswordMessage`.
    // Confirm its `Debug` impl does not expose the password at trace level.
    trace!("[pg] Decoded {:X?}", message);
    Ok(message)
}
/// Reads the body of one message whose one-byte type tag has already been consumed.
///
/// On the wire the body is preceded by a big-endian `u32` length that counts
/// itself. The body is therefore `length - 4` bytes long. `message_tag` is only
/// used for tracing.
///
/// The declared length is controlled by the peer, so it is not trusted for an
/// up-front allocation. The buffer starts with a bounded capacity and grows only
/// as bytes actually arrive.
///
/// # Errors
/// - `ErrorKind::Other` if the declared length is smaller than 4.
/// - `ErrorKind::OutOfMemory` if the body length does not fit in `usize`.
/// - `ErrorKind::UnexpectedEof` if the stream ends before the full body is read.
/// - Any I/O error returned by `reader`.
pub async fn read_contents<Reader: AsyncReadExt + Unpin>(
    reader: &mut Reader,
    message_tag: u8,
) -> Result<Cursor<Vec<u8>>, Error> {
    // Upper bound on the initial reservation. Larger bodies grow the buffer
    // incrementally while they are being received.
    const INITIAL_READ_CAPACITY: usize = 8 * 1024;

    let length = reader.read_u32().await?;
    if length < 4 {
        return Err(Error::new(
            ErrorKind::Other,
            "Unexpectedly small (<0) message size",
        ));
    }

    trace!(
        "[pg] Receive package {:X?} with length {}",
        message_tag,
        length
    );

    let body_length = length - 4;
    let expected = usize::try_from(body_length).map_err(|_| {
        Error::new(
            ErrorKind::OutOfMemory,
            "Unable to convert message length to a suitable memory size",
        )
    })?;

    // Without this bound, a client claiming a ~4 GiB body and then sending nothing
    // would force the full allocation immediately. `take` caps the read at the
    // declared length, so bytes of the next message are never consumed.
    let mut buffer = Vec::with_capacity(expected.min(INITIAL_READ_CAPACITY));
    let received = (&mut *reader)
        .take(u64::from(body_length))
        .read_to_end(&mut buffer)
        .await?;
    if received != expected {
        // Same error kind `read_exact` reports when the stream ends early.
        return Err(Error::new(ErrorKind::UnexpectedEof, "early eof"));
    }

    Ok(Cursor::new(buffer))
}
/// Reads a NUL-terminated string from `reader` and decodes it as UTF-8.
///
/// The terminating `0` byte is consumed but is not part of the returned string.
///
/// # Errors
/// Propagates any read error, for example when the input ends before a
/// terminator is found. Returns `ErrorKind::InvalidData` if the collected bytes
/// are not valid UTF-8.
pub async fn read_string<Reader: AsyncReadExt + Unpin>(
    reader: &mut Reader,
) -> Result<String, Error> {
    let mut raw = Vec::with_capacity(64);
    loop {
        match reader.read_u8().await? {
            0 => break,
            byte => raw.push(byte),
        }
    }

    String::from_utf8(raw).map_err(|_| {
        Error::new(
            ErrorKind::InvalidData,
            "Unable to parse bytes as a UTF-8 string",
        )
    })
}
pub async fn read_format<Reader: AsyncReadExt + Unpin>(
reader: &mut Reader,
) -> Result<protocol::Format, ProtocolError> {
match reader.read_i16().await? {
0 => Ok(protocol::Format::Text),
1 => Ok(protocol::Format::Binary),
format_code => Err(protocol::ErrorResponse::error(
protocol::ErrorCode::ProtocolViolation,
format!("Unknown format code: {}", format_code),
)
.into()),
}
}
/// Writes a message's serialized bytes to `writer` verbatim, then flushes.
///
/// Unlike [`write_message`], no type byte or length prefix is added. The output
/// of `serialize()` is sent as-is. If `serialize()` returns `None`, nothing is
/// written and the writer is not flushed.
///
/// # Errors
/// Propagates any write or flush error.
pub async fn write_direct<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
    writer: &mut Writer,
    message: Message,
) -> Result<(), ProtocolError> {
    if let Some(payload) = message.serialize() {
        writer.write_all(&payload).await?;
        writer.flush().await?;
    }
    Ok(())
}
/// Appends one framed message to `packet_buffer`.
///
/// The frame is built in two parts:
/// - The one-byte type code, written unless `message.code()` is `0x00`.
/// - If `serialize()` yields a payload: a big-endian `u32` length that counts
///   itself (payload length + 4), followed by the payload bytes.
///
/// # Errors
/// Returns an `InternalError` response if the framed length does not fit in a `u32`.
fn message_serialize<Message: Serialize>(
    message: Message,
    packet_buffer: &mut BytesMut,
) -> Result<(), ProtocolError> {
    if message.code() != 0x00 {
        packet_buffer.put_u8(message.code());
    }

    // NOTE(review): when `code()` is non-zero but `serialize()` returns `None`, only
    // the type byte is emitted, with no length word. Confirm no such message exists.
    if let Some(payload) = message.serialize() {
        let framed_len = u32::try_from(payload.len() + 4).map_err(|_| {
            ErrorResponse::error(
                ErrorCode::InternalError,
                "Unable to convert buffer length to a suitable memory size".to_string(),
            )
        })?;
        // `put_u32` writes big-endian, matching the network byte order of the length.
        packet_buffer.put_u32(framed_len);
        packet_buffer.put_slice(&payload);
    }

    Ok(())
}
/// Frames every message in `messages` into one buffer, then writes it with a
/// single `write_all` and flushes.
///
/// # Errors
/// Fails on the first message that cannot be framed; nothing is written in that
/// case. Also propagates any write or flush error.
pub async fn write_messages<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
    writer: &mut Writer,
    messages: Vec<Message>,
) -> Result<(), ProtocolError> {
    let mut packet = BytesMut::with_capacity(messages.len() * 64);
    messages
        .into_iter()
        .try_for_each(|message| message_serialize(message, &mut packet))?;

    writer.write_all(&packet).await?;
    writer.flush().await?;
    Ok(())
}
pub async fn write_message<Writer: AsyncWriteExt + Unpin, Message: Serialize>(
writer: &mut Writer,
message: Message,
) -> Result<(), ProtocolError> {
let mut buffer = BytesMut::with_capacity(64);
message_serialize(message, &mut buffer)?;
writer.write_all(&buffer).await?;
writer.flush().await?;
Ok(())
}
/// Appends `string` to `buffer` as a NUL-terminated byte string.
///
/// The UTF-8 bytes of `string` are written first, followed by a single `0` byte.
/// Any NUL bytes already inside `string` are copied unchanged.
pub fn write_string(buffer: &mut Vec<u8>, string: &str) {
    let terminated = string.bytes().chain(std::iter::once(b'\0'));
    buffer.extend(terminated);
}