use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
use super::{PostgresMessage, SQLMessage};
use crate::codec::constants::*;
use crate::codec::utils::*;
/// Size in bytes of a startup-phase frame header: a 4-byte length followed by
/// a 4-byte code. The startup decoder waits for at least this many bytes.
const BYTES_STARTUP_MESSAGE_HEADER: usize = 8;
/// Startup-phase code identifying an SSL request.
const MESSAGE_ID_SSL_REQUEST: i32 = 80877103;
/// Startup-phase code identifying a startup message. The value is `3 << 16`
/// (wire bytes `0, 3, 0, 0`); any other non-SSL code is rejected by the decoder
/// as an invalid protocol version.
const MESSAGE_ID_STARTUP_MESSAGE: i32 = 196608;
/// Tag byte of an `Execute` message.
const MESSAGE_ID_EXECUTE: u8 = b'E';
/// Tag byte of a `Flush` message.
const MESSAGE_ID_FLUSH: u8 = b'H';
/// Tag byte of a `Query` message.
const MESSAGE_ID_QUERY: u8 = b'Q';
/// Tag byte shared by `SASLInitialResponse` and `SASLResponse`; the decoder
/// tells them apart by inspecting the payload.
const MESSAGE_ID_SASL: u8 = b'p';
/// Tag byte of a `Sync` message.
const MESSAGE_ID_SYNC: u8 = b'S';
/// Tag byte of a `Terminate` message.
const MESSAGE_ID_TERMINATE: u8 = b'X';
/// Messages handled by [`Codec`].
///
/// Only a subset of the variants has a dedicated decode and/or encode arm in
/// this file; the remaining ones are declared so the type can name them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Message {
/// Fallback produced by the decoder for any tag byte without a dedicated arm.
/// Holds the payload that follows the length field; the tag byte itself is
/// not retained.
NotImplemented(Bytes),
/// Produced by the decoder for frames tagged `b'B'`: the payload is skipped
/// and the total frame length, truncated to `u8`, is stored.
/// NOTE(review): looks like diagnostic/test scaffolding — confirm intent.
Canary(u8),
/// Bind request. Not constructed by the decoder in this file (frames tagged
/// `b'B'` currently decode to [`Message::Canary`]).
Bind {
/// Portal name (presumably the destination portal — confirm against caller).
portal: Bytes,
/// Statement name (presumably the source prepared statement — confirm).
stmt_name: Bytes,
/// Parameters, each carrying its own format code and raw value.
parameters: Vec<BindParameter>,
/// Format codes for result columns (meaning of values not shown here).
results_formats: Vec<u16>,
},
/// Execute request, tagged `b'E'`.
Execute {
/// Portal name, transmitted as a C string.
portal: Bytes,
/// Row limit, transmitted as a 32-bit integer after the portal name.
max_rows: u32,
},
/// Flush request, tagged `b'H'`; carries no payload.
Flush(),
/// Simple query, tagged `b'Q'`; the payload is a single C string.
Query(Bytes),
/// First SASL message, tagged `b'p'`: decoded when the payload starts with a
/// readable C string.
SASLInitialResponse {
/// SASL mechanism name (field name spelling is part of the public API).
mecanism: Bytes,
/// Mechanism-specific initial response data.
response: Bytes,
},
/// Subsequent SASL message, tagged `b'p'`: the whole payload is the response.
/// The decoder rejects an empty payload.
SASLResponse(Bytes),
/// Startup-phase SSL request (code `80877103`); carries no payload.
SSLRequest(),
/// Startup-phase message carrying the connection parameters.
StartupMessage {
/// Frame length as read from the wire (it includes the 4-byte length field
/// itself). The encoder writes this value back verbatim.
frame_length: usize,
/// Name/value pairs in wire order. The decoder requires a `user` entry.
parameters: Vec<Parameter>,
},
/// Sync request, tagged `b'S'`; carries no payload.
Sync(),
/// Terminate request, tagged `b'X'`; carries no payload.
Terminate(),
// The variants below have no dedicated decode or encode arm in this file;
// the layout of their `Bytes` payload is not defined here.
CancelRequest(Bytes),
Close(Bytes),
CopyData(Bytes),
CopyDone(Bytes),
CopyFail(Bytes),
Describe(Bytes),
FunctionCall(Bytes),
GSSENCRequest(Bytes),
GSSResponse(Bytes),
Parse(Bytes),
PasswordMessage(Bytes),
}
// Empty impls: `Message` opts into both traits from the parent module without
// providing any items here. What the traits require is defined in `super`.
impl PostgresMessage for Message {}
impl SQLMessage for Message {}
/// One parameter of a [`Message::Bind`] request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BindParameter {
/// Format code of `value` (the meaning of individual codes is not defined in
/// this file).
pub format: u16,
/// Raw parameter value.
pub value: Bytes,
}
/// One name/value pair of a [`Message::StartupMessage`].
///
/// Both fields are transmitted as C strings; the stored bytes are whatever the
/// C-string helper returns (see `codec::utils`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Parameter {
/// Parameter name, e.g. `user`.
pub name: Bytes,
/// Parameter value.
pub value: Bytes,
}
/// Position of the decoder state machine.
#[derive(Debug, Clone)]
enum DecodeState {
/// Expecting a startup-phase frame (length + code, no tag byte).
Startup,
/// Expecting the header (tag byte + length) of a regular message.
Head,
/// Header already parsed; waiting until the whole frame of the given size
/// (tag byte and length field included) is buffered.
Message(usize),
}
/// Decoder/encoder for [`Message`] values, usable with `tokio_util::codec`.
///
/// A fresh codec starts in the startup phase and moves to regular message
/// decoding once a startup message has been decoded (or when
/// `startup_complete` is called).
#[derive(Debug, Clone)]
pub struct Codec {
// Current position of the decoder state machine.
state: DecodeState,
}
impl Codec {
#[must_use]
pub const fn new() -> Self {
Self {
state: DecodeState::Startup,
}
}
pub fn startup_complete(&mut self) {
self.state = DecodeState::Head;
}
fn decode_header(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
if src.len() < BYTES_MESSAGE_HEADER {
log::trace!(
"not enough header data ({} bytes), awaiting more ({} bytes)",
src.len(),
BYTES_MESSAGE_HEADER,
);
return Ok(None);
}
let mut buf = io::Cursor::new(&mut *src);
buf.advance(BYTES_MESSAGE_ID);
let frame_length = (buf.get_u32() as usize) + BYTES_MESSAGE_ID;
if frame_length < BYTES_MESSAGE_HEADER {
log::trace!("invalid frame: {:?}", buf);
let err = io::Error::new(
io::ErrorKind::InvalidInput,
"malformed packet - invalid message length",
);
log::error!("{}", err);
return Err(err);
}
Ok(Some(frame_length))
}
fn decode_message(&mut self, len: usize, src: &mut BytesMut) -> io::Result<Option<Message>> {
if src.len() < len {
log::trace!(
"not enough message data ({} bytes), awaiting more ({} bytes)",
src.len(),
len
);
return Ok(None);
}
let mut frame = src.split_to(len);
let msg_id = frame.get_u8();
log::trace!("incoming msg id: '{}' ({})", msg_id as char, msg_id);
let msg_length = (frame.get_u32() as usize) - BYTES_MESSAGE_SIZE;
log::trace!("incoming msg length: {}", msg_length);
let msg = match msg_id {
b'B' => {
frame.advance(msg_length);
Message::Canary(len as u8)
},
b'!' => {
return Err(io::Error::new(io::ErrorKind::InvalidData, "expected canary error"));
},
MESSAGE_ID_EXECUTE => {
let portal = get_cstr(&mut frame)?;
let max_rows = get_u32(&mut frame, "malformed packet - invalid execute data")?;
Message::Execute { portal, max_rows }
},
MESSAGE_ID_FLUSH => Message::Flush(),
MESSAGE_ID_QUERY => {
let query = get_cstr(&mut frame)?;
Message::Query(query)
},
MESSAGE_ID_SASL => {
if let Ok(mecanism) = get_cstr(&mut frame) {
const SASL_RESPONSE_SIZE_BYTES: usize = 4;
let response = get_bytes(
&mut frame,
SASL_RESPONSE_SIZE_BYTES,
"malformed packet - invalid SASL response data",
)?;
Message::SASLInitialResponse { mecanism, response }
} else {
let response = frame.copy_to_bytes(frame.remaining());
if response.is_empty() {
let err = std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"malformed packet - invalid SASL response data",
);
log::error!("{}", err);
return Err(err);
}
Message::SASLResponse(response)
}
},
MESSAGE_ID_SYNC => Message::Sync(),
MESSAGE_ID_TERMINATE => Message::Terminate(),
_ => {
let bytes = frame.copy_to_bytes(msg_length);
Message::NotImplemented(bytes)
},
};
if !frame.is_empty() {
log::trace!("invalid frame: {:?}", frame);
let err = std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"malformed packet - invalid message length",
);
log::error!("{}", err);
return Err(err);
}
log::debug!("decoded message frame: {:?}", msg);
Ok(Some(msg))
}
pub fn decode_startup_message(&mut self, src: &mut BytesMut) -> io::Result<Option<Message>> {
if src.len() < BYTES_STARTUP_MESSAGE_HEADER {
log::trace!(
"not enough header data ({} bytes), awaiting more ({} bytes)",
src.len(),
BYTES_STARTUP_MESSAGE_HEADER,
);
return Ok(None);
}
let mut buf = io::Cursor::new(&mut *src);
let frame_length = buf.get_u32() as usize;
if src.len() < frame_length {
log::trace!(
"not enough message data ({} bytes), awaiting more ({} bytes)",
src.len(),
frame_length,
);
return Ok(None);
}
let mut frame = src.split_to(frame_length);
log::trace!("decoded frame length: {}", frame_length);
frame.advance(4);
let msg_id = frame.get_i32();
log::trace!("msg id: {}", msg_id);
let msg = match msg_id {
MESSAGE_ID_STARTUP_MESSAGE => {
let mut parameters = Vec::new();
let mut user_param_exists = false;
while frame.remaining() > 2 {
let parameter_name = get_cstr(&mut frame)?;
if parameter_name == "user" {
user_param_exists = true;
}
let parameter = Parameter {
name: parameter_name,
value: get_cstr(&mut frame)?,
};
log::trace!("decoded parameter: {:?}", parameter);
parameters.push(parameter);
}
if frame.remaining() < 1 || !user_param_exists {
let err = std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"malformed packet - missing parameter fields",
);
log::error!("{}", err);
return Err(err);
}
frame.advance(1);
Message::StartupMessage {
frame_length,
parameters,
}
}
MESSAGE_ID_SSL_REQUEST => Message::SSLRequest(),
_ => {
let err = std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"malformed packet - invalid protocol version",
);
log::error!("{}", err);
return Err(err);
}
};
log::debug!("decoded message frame: {:?}", msg);
Ok(Some(msg))
}
fn encode_header(&mut self, msg_id: u8, msg_size: usize, dst: &mut BytesMut) {
dst.reserve(BYTES_MESSAGE_HEADER + msg_size);
dst.put_u8(msg_id);
dst.put_u32((BYTES_MESSAGE_SIZE + msg_size) as u32);
}
}
impl Decoder for Codec {
    type Item = Message;
    type Error = io::Error;

    /// Drives the decoder state machine over the bytes buffered in `src`.
    ///
    /// * In the startup phase a single startup frame is decoded; a startup
    ///   message switches the codec to regular decoding, an SSL request leaves
    ///   it in the startup phase.
    /// * Otherwise the header is parsed first (remembering the frame size) and
    ///   the message is decoded once the whole frame is available.
    ///
    /// Returns `Ok(None)` whenever more input is required.
    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Self::Item>> {
        log::trace!("decoder state: {:?}", self.state);

        let frame_len = match self.state {
            // Header was parsed on an earlier call; keep waiting for the body.
            DecodeState::Message(pending) => pending,
            DecodeState::Head => {
                let parsed = match self.decode_header(src)? {
                    Some(parsed) => parsed,
                    None => return Ok(None),
                };
                self.state = DecodeState::Message(parsed);
                src.reserve(parsed);
                log::trace!("stream buffer capacity: {} bytes", src.capacity());
                parsed
            }
            DecodeState::Startup => {
                return match self.decode_startup_message(src)? {
                    None => Ok(None),
                    Some(Message::SSLRequest()) => Ok(Some(Message::SSLRequest())),
                    Some(startup @ Message::StartupMessage { .. }) => {
                        self.startup_complete();
                        Ok(Some(startup))
                    }
                    Some(unexpected) => {
                        let err = io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("unexpected message during startup: {:?}", unexpected),
                        );
                        log::error!("{}", err);
                        Err(err)
                    }
                };
            }
        };
        log::trace!("decoded frame length: {} bytes", frame_len);

        let decoded = self.decode_message(frame_len, src)?;
        if decoded.is_some() {
            // Frame consumed: go back to expecting a header.
            self.state = DecodeState::Head;
            src.reserve(BYTES_MESSAGE_HEADER);
            log::trace!("stream buffer capacity: {} bytes", src.capacity());
        }
        Ok(decoded)
    }
}
impl Encoder<Message> for Codec {
type Error = io::Error;
fn encode(&mut self, msg: Message, dst: &mut BytesMut) -> Result<(), io::Error> {
match msg {
Message::Execute { portal, max_rows } => {
self.encode_header(MESSAGE_ID_EXECUTE, portal.len() + 1 + 4, dst);
put_cstr(&portal, dst);
dst.put_i32(max_rows as i32);
}
Message::Flush() => {
self.encode_header(MESSAGE_ID_FLUSH, 0, dst);
}
Message::Query(query) => {
self.encode_header(MESSAGE_ID_QUERY, query.len() + 1, dst);
put_cstr(&query, dst);
}
Message::SASLInitialResponse { mecanism, response } => {
self.encode_header(
MESSAGE_ID_SASL,
mecanism.len() + 1 + 4 + response.len(),
dst,
);
put_cstr(&mecanism, dst);
put_bytes(&response, dst);
}
Message::SASLResponse(response) => {
self.encode_header(MESSAGE_ID_SASL, response.len(), dst);
dst.put(response);
}
Message::StartupMessage {
frame_length,
parameters,
} => {
dst.reserve(frame_length);
dst.put_i32(frame_length as i32);
dst.put_i32(196608);
for parameter in ¶meters {
put_cstr(¶meter.name, dst);
put_cstr(¶meter.value, dst);
}
dst.put_u8(0); }
Message::SSLRequest() => {
dst.reserve(8);
dst.put_i32(8);
dst.put_i32(80877103);
}
Message::Sync() => {
self.encode_header(MESSAGE_ID_SYNC, 0, dst);
}
Message::Terminate() => {
self.encode_header(MESSAGE_ID_TERMINATE, 0, dst);
}
other => {
unimplemented!("not implemented: {:?}", other)
}
}
Ok(())
}
}
impl Default for Codec {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod decode_tests {
    use bytes::{Bytes, BytesMut};
    use test_log::test;

    use super::{Codec, Message, Parameter};

    /// Feeds `data` to a fresh codec's startup decoder until it stops yielding
    /// messages (more input needed, or an error), then checks the decoded
    /// messages and the number of bytes left in the buffer.
    fn assert_decode_startup_message(data: &[u8], expected: &[Message], remaining: usize) {
        let mut codec = Codec::new();
        let mut buf = BytesMut::from(data);
        let mut decoded = Vec::new();

        loop {
            match codec.decode_startup_message(&mut buf) {
                Ok(Some(msg)) => decoded.push(msg),
                Ok(None) | Err(_) => break,
            }
        }

        assert_eq!(remaining, buf.len(), "remaining bytes in read buffer");
        assert_eq!(expected.len(), decoded.len(), "decoded messages");
        assert_eq!(expected, decoded, "decoded messages");
    }

    /// Builds an expected startup parameter from static byte strings.
    fn param(name: &'static [u8], value: &'static [u8]) -> Parameter {
        Parameter {
            name: Bytes::from_static(name),
            value: Bytes::from_static(value),
        }
    }

    #[test]
    fn valid_startup_message() {
        // length = 78 (0x4e), protocol 3.0, four parameters, list terminator
        let data = b"\0\0\0\x4e\0\x03\0\0user\0root\0database\0testdb\0application_name\0psql\0client_encoding\0UTF8\0\0";
        let expected = [Message::StartupMessage {
            frame_length: 78,
            parameters: vec![
                param(b"user", b"root"),
                param(b"database", b"testdb"),
                param(b"application_name", b"psql"),
                param(b"client_encoding", b"UTF8"),
            ],
        }];
        assert_decode_startup_message(data, &expected, 0);
    }

    #[test]
    fn invalid_startup_message_wrong_protocol_version() {
        // same as the valid message but with protocol bytes 0, 2, 0, 0
        let data = b"\0\0\0\x4e\0\x02\0\0user\0root\0database\0testdb\0application_name\0psql\0client_encoding\0UTF8\0\0";
        assert_decode_startup_message(data, &[], 0);
    }

    #[test]
    fn invalid_startup_message_missing_required_user() {
        // length = 68 (0x44), no `user` parameter
        let data = b"\0\0\0\x44\0\x03\0\0database\0testdb\0application_name\0psql\0client_encoding\0UTF8\0\0";
        assert_decode_startup_message(data, &[], 0);
    }

    #[test]
    fn invalid_startup_message_empty_parameters_list() {
        // length = 9: header plus the list terminator only
        let data = b"\0\0\0\x09\0\x03\0\0\0";
        assert_decode_startup_message(data, &[], 0);
    }

    #[test]
    fn invalid_startup_message_missing_parameters_data() {
        // length = 8: header only
        let data = b"\0\0\0\x08\0\x03\0\0";
        assert_decode_startup_message(data, &[], 0);
    }

    #[test]
    fn invalid_startup_message_missing_parameters_list_terminator() {
        // length = 77 (0x4d): the valid message without its final terminator
        let data = b"\0\0\0\x4d\0\x03\0\0user\0root\0database\0testdb\0application_name\0psql\0client_encoding\0UTF8\0";
        assert_decode_startup_message(data, &[], 0);
    }

    #[test]
    fn invalid_startup_message_missing_parameter_field() {
        // length = 28 (0x1c): `database` is followed by an empty value and no terminator
        let data = b"\0\0\0\x1c\0\x03\0\0user\0root\0database\0\0";
        assert_decode_startup_message(data, &[], 0);
    }
}