use crate::{cursor::Cursor, types::PgType, ConnId, ConnSecretKey, Error, PgFormat, Result, UnrecognizedFormat};
use std::convert::TryFrom;
// One-byte message tags of the wire protocol handled by this module.
// Several constants share the same byte value (for example `b'C'`, `b'S'`,
// `b'D'`, `b'E'`); they are used in different contexts (backend tag,
// error-field marker, frontend tag), so they never clash.

// Tags written as the first byte of serialized `BackendMessage`s.
const COMMAND_COMPLETE: u8 = b'C';
const DATA_ROW: u8 = b'D';
const ERROR_RESPONSE: u8 = b'E';
// Field markers written inside the body of an `ErrorResponse`, each followed
// by a NUL-terminated string.
const SEVERITY: u8 = b'S';
const CODE: u8 = b'C';
const MESSAGE: u8 = b'M';
// Also reused as the trailing status byte of `ReadyForQuery`.
const EMPTY_QUERY_RESPONSE: u8 = b'I';
const NOTICE_RESPONSE: u8 = b'N';
const AUTHENTICATION: u8 = b'R';
const BACKEND_KEY_DATA: u8 = b'K';
const PARAMETER_STATUS: u8 = b'S';
const ROW_DESCRIPTION: u8 = b'T';
const READY_FOR_QUERY: u8 = b'Z';
const PARAMETER_DESCRIPTION: u8 = b't';
const NO_DATA: u8 = b'n';
const PARSE_COMPLETE: u8 = b'1';
const BIND_COMPLETE: u8 = b'2';
const CLOSE_COMPLETE: u8 = b'3';
// Tags matched by `FrontendMessage::decode` to select the message decoder.
// `QUERY` is the only tag exposed to the rest of the crate.
pub(crate) const QUERY: u8 = b'Q';
const BIND: u8 = b'B';
const CLOSE: u8 = b'C';
const DESCRIBE: u8 = b'D';
const EXECUTE: u8 = b'E';
const FLUSH: u8 = b'H';
const PARSE: u8 = b'P';
const SYNC: u8 = b'S';
const TERMINATE: u8 = b'X';
/// Messages sent by a client (frontend).
///
/// `Query`, `Parse`, `Describe*`, `Bind`, `Execute`, `Flush`, `Sync`,
/// `Close*` and `Terminate` are produced by [`FrontendMessage::decode`].
/// `GssencRequest`, `SslRequest` and `Setup` are never produced by `decode`;
/// presumably they are built by the connection start-up code elsewhere in the
/// crate (TODO confirm).
#[derive(Debug, PartialEq)]
pub enum FrontendMessage {
/// Request for GSSAPI encryption (not produced by `decode`).
GssencRequest,
/// Request for SSL encryption (not produced by `decode`).
SslRequest,
/// Connection set-up carrying `(name, value)` parameter pairs
/// (not produced by `decode`).
Setup {
/// Parameter `(name, value)` pairs.
params: Vec<(String, String)>,
},
/// Query message carrying the SQL text read from the message body.
Query {
/// SQL text, read as a NUL-terminated string.
sql: String,
},
/// Request to parse `sql` into a statement named `statement_name`.
Parse {
/// Statement name; may be empty (see the `parse` unit test).
statement_name: String,
/// SQL text of the statement.
sql: String,
/// One entry per parameter OID found in the message, as returned by
/// `PgType::from_oid`.
param_types: Vec<Option<PgType>>,
},
/// Describe request whose type byte was `b'S'`.
DescribeStatement {
/// Name of the statement to describe.
name: String,
},
/// Describe request whose type byte was `b'P'`.
DescribePortal {
/// Name of the portal to describe.
name: String,
},
/// Request to bind parameter values of a statement to a portal.
Bind {
/// Name of the portal.
portal_name: String,
/// Name of the statement.
statement_name: String,
/// Format codes of the parameters, in message order.
param_formats: Vec<PgFormat>,
/// Raw parameter values; `None` where the message carried length `-1`.
raw_params: Vec<Option<Vec<u8>>>,
/// Format codes requested for the result columns, in message order.
result_formats: Vec<PgFormat>,
},
/// Request to execute a portal.
Execute {
/// Name of the portal to execute.
portal_name: String,
/// Row limit as sent by the client (read as a big-endian `i32`
/// following the portal name).
max_rows: i32,
},
/// Flush request; carries no payload.
Flush,
/// Sync request; carries no payload.
Sync,
/// Close request whose type byte was `b'S'`.
CloseStatement {
/// Name of the statement to close.
name: String,
},
/// Close request whose type byte was `b'P'`.
ClosePortal {
/// Name of the portal to close.
name: String,
},
/// Terminate request; carries no payload.
Terminate,
}
impl FrontendMessage {
pub fn decode(tag: u8, buffer: &[u8]) -> Result<Self> {
log::trace!("Receives frontend tag = {:?}, buffer = {:?}", char::from(tag), buffer);
let cursor = Cursor::from(buffer);
match tag {
QUERY => decode_query(cursor),
BIND => decode_bind(cursor),
CLOSE => decode_close(cursor),
DESCRIBE => decode_describe(cursor),
EXECUTE => decode_execute(cursor),
FLUSH => decode_flush(cursor),
PARSE => decode_parse(cursor),
SYNC => decode_sync(cursor),
TERMINATE => decode_terminate(cursor),
_ => {
log::error!("unsupported frontend message tag {}", tag);
Err(Error::UnsupportedFrontendMessage)
}
}
}
}
/// Messages sent by the server (backend); serialized with
/// [`BackendMessage::as_vec`].
// NOTE(review): `dead_code` is allowed here, presumably because not every
// variant is constructed yet — confirm and document or remove.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
pub enum BackendMessage {
/// Notice; serialized as the bare tag byte only.
NoticeResponse,
/// Authentication request with sub-code 3.
AuthenticationCleartextPassword,
/// Authentication request with sub-code 5 followed by four fixed bytes.
AuthenticationMD5Password,
/// Authentication request with sub-code 0.
AuthenticationOk,
/// Connection id and secret key, each serialized big-endian.
BackendKeyData(ConnId, ConnSecretKey),
/// Ready-for-query marker; always serialized with status byte `b'I'`.
ReadyForQuery,
/// One row; every field is serialized as its UTF-8 bytes prefixed with
/// an `i32` length.
DataRow(Vec<String>),
/// Description of the result columns.
RowDescription(Vec<ColumnMetadata>),
/// Command tag text, serialized NUL-terminated.
CommandComplete(String),
/// Empty-query marker; carries no payload.
EmptyQueryResponse,
/// Optional severity, code and message, in that order; absent parts are
/// omitted from the serialized form.
ErrorResponse(Option<&'static str>, Option<&'static str>, Option<String>),
/// Parameter name and value, both serialized NUL-terminated.
ParameterStatus(String, String),
/// Types of the statement parameters, serialized as their OIDs.
ParameterDescription(Vec<PgType>),
/// No-data marker; carries no payload.
NoData,
/// Parse-complete marker; carries no payload.
ParseComplete,
/// Bind-complete marker; carries no payload.
BindComplete,
/// Close-complete marker; carries no payload.
CloseComplete,
}
impl BackendMessage {
pub fn as_vec(&self) -> Vec<u8> {
match self {
BackendMessage::NoticeResponse => vec![NOTICE_RESPONSE],
BackendMessage::AuthenticationCleartextPassword => vec![AUTHENTICATION, 0, 0, 0, 8, 0, 0, 0, 3],
BackendMessage::AuthenticationMD5Password => vec![AUTHENTICATION, 0, 0, 0, 12, 0, 0, 0, 5, 1, 1, 1, 1],
BackendMessage::AuthenticationOk => vec![AUTHENTICATION, 0, 0, 0, 8, 0, 0, 0, 0],
BackendMessage::BackendKeyData(conn_id, secret_key) => {
let mut buff = vec![BACKEND_KEY_DATA, 0, 0, 0, 12];
buff.extend_from_slice(&conn_id.to_be_bytes());
buff.extend_from_slice(&secret_key.to_be_bytes());
buff
}
BackendMessage::ReadyForQuery => vec![READY_FOR_QUERY, 0, 0, 0, 5, EMPTY_QUERY_RESPONSE],
BackendMessage::DataRow(row) => {
let mut row_buff = Vec::new();
for field in row.iter() {
row_buff.extend_from_slice(&(field.len() as i32).to_be_bytes());
row_buff.extend_from_slice(field.as_str().as_bytes());
}
let mut len_buff = Vec::new();
len_buff.extend_from_slice(&[DATA_ROW]);
len_buff.extend_from_slice(&(6 + row_buff.len() as i32).to_be_bytes());
len_buff.extend_from_slice(&(row.len() as i16).to_be_bytes());
len_buff.extend_from_slice(&row_buff);
len_buff
}
BackendMessage::RowDescription(description) => {
let mut buff = Vec::new();
for field in description.iter() {
buff.extend_from_slice(field.name.as_str().as_bytes());
buff.extend_from_slice(&[0]);
buff.extend_from_slice(&(0i32).to_be_bytes());
buff.extend_from_slice(&(0i16).to_be_bytes());
buff.extend_from_slice(&field.type_id.to_be_bytes());
buff.extend_from_slice(&field.type_size.to_be_bytes());
buff.extend_from_slice(&(-1i32).to_be_bytes());
buff.extend_from_slice(&0i16.to_be_bytes());
}
let mut len_buff = Vec::new();
len_buff.extend_from_slice(&[ROW_DESCRIPTION]);
len_buff.extend_from_slice(&(6 + buff.len() as i32).to_be_bytes());
len_buff.extend_from_slice(&(description.len() as i16).to_be_bytes());
len_buff.extend_from_slice(&buff);
len_buff
}
BackendMessage::CommandComplete(command) => {
let mut command_buff = Vec::new();
command_buff.extend_from_slice(&[COMMAND_COMPLETE]);
command_buff.extend_from_slice(&(4 + command.len() as i32 + 1).to_be_bytes());
command_buff.extend_from_slice(command.as_bytes());
command_buff.extend_from_slice(&[0]);
command_buff
}
BackendMessage::EmptyQueryResponse => vec![EMPTY_QUERY_RESPONSE, 0, 0, 0, 4],
BackendMessage::ErrorResponse(severity, code, message) => {
let mut error_response_buff = Vec::new();
error_response_buff.extend_from_slice(&[ERROR_RESPONSE]);
let mut message_buff = Vec::new();
if let Some(severity) = severity.as_ref() {
message_buff.extend_from_slice(&[SEVERITY]);
message_buff.extend_from_slice(severity.as_bytes());
message_buff.extend_from_slice(&[0]);
}
if let Some(code) = code.as_ref() {
message_buff.extend_from_slice(&[CODE]);
message_buff.extend_from_slice(code.as_bytes());
message_buff.extend_from_slice(&[0]);
}
if let Some(message) = message.as_ref() {
message_buff.extend_from_slice(&[MESSAGE]);
message_buff.extend_from_slice(message.as_bytes());
message_buff.extend_from_slice(&[0]);
}
error_response_buff.extend_from_slice(&(message_buff.len() as i32 + 4 + 1).to_be_bytes());
error_response_buff.extend_from_slice(message_buff.as_ref());
error_response_buff.extend_from_slice(&[0]);
error_response_buff.to_vec()
}
BackendMessage::ParameterStatus(name, value) => {
let mut parameter_status_buff = Vec::new();
parameter_status_buff.extend_from_slice(&[PARAMETER_STATUS]);
let mut parameters = Vec::new();
parameters.extend_from_slice(name.as_bytes());
parameters.extend_from_slice(&[0]);
parameters.extend_from_slice(value.as_bytes());
parameters.extend_from_slice(&[0]);
parameter_status_buff.extend_from_slice(&(4 + parameters.len() as u32).to_be_bytes());
parameter_status_buff.extend_from_slice(parameters.as_ref());
parameter_status_buff
}
BackendMessage::ParameterDescription(pg_types) => {
let mut type_id_buff = Vec::new();
for pg_type in pg_types.iter() {
type_id_buff.extend_from_slice(&pg_type.type_oid().to_be_bytes());
}
let mut buff = Vec::new();
buff.extend_from_slice(&[PARAMETER_DESCRIPTION]);
buff.extend_from_slice(&(6 + type_id_buff.len() as i32).to_be_bytes());
buff.extend_from_slice(&(pg_types.len() as i16).to_be_bytes());
buff.extend_from_slice(&type_id_buff);
buff
}
BackendMessage::NoData => vec![NO_DATA, 0, 0, 0, 4],
BackendMessage::ParseComplete => vec![PARSE_COMPLETE, 0, 0, 0, 4],
BackendMessage::BindComplete => vec![BIND_COMPLETE, 0, 0, 0, 4],
BackendMessage::CloseComplete => vec![CLOSE_COMPLETE, 0, 0, 0, 4],
}
}
}
/// Name and type information of one result column, as serialized inside
/// `BackendMessage::RowDescription`.
#[derive(Clone, Debug, PartialEq)]
pub struct ColumnMetadata {
/// Column name.
pub name: String,
/// Type OID of the column (taken from `PgType::type_oid` by `new`).
pub type_id: u32,
/// Type length of the column (taken from `PgType::type_len` by `new`).
pub type_size: i16,
}
impl ColumnMetadata {
pub fn new<S: ToString>(name: S, pg_type: PgType) -> ColumnMetadata {
Self {
name: name.to_string(),
type_id: pg_type.type_oid(),
type_size: pg_type.type_len(),
}
}
}
fn decode_bind(mut cursor: Cursor) -> Result<FrontendMessage> {
let portal_name = cursor.read_cstr()?.to_owned();
let statement_name = cursor.read_cstr()?.to_owned();
let mut param_formats = vec![];
for _ in 0..cursor.read_i16()? {
match PgFormat::try_from(cursor.read_i16()?) {
Ok(format) => param_formats.push(format),
Err(UnrecognizedFormat(code)) => return Err(Error::InvalidInput(format!("unknown format code: {}", code))),
}
}
let mut raw_params = vec![];
for _ in 0..cursor.read_i16()? {
let len = cursor.read_i32()?;
if len == -1 {
raw_params.push(None);
} else {
let mut value = vec![];
for _ in 0..len {
value.push(cursor.read_byte()?);
}
raw_params.push(Some(value));
}
}
let mut result_formats = vec![];
for _ in 0..cursor.read_i16()? {
match PgFormat::try_from(cursor.read_i16()?) {
Ok(format) => result_formats.push(format),
Err(UnrecognizedFormat(code)) => return Err(Error::InvalidInput(format!("unknown format code: {}", code))),
}
}
Ok(FrontendMessage::Bind {
portal_name,
statement_name,
param_formats,
raw_params,
result_formats,
})
}
/// Decodes the body of a `Close` message.
///
/// The body is a type byte followed by a NUL-terminated name: `b'P'` yields
/// `FrontendMessage::ClosePortal`, `b'S'` yields
/// `FrontendMessage::CloseStatement`.
///
/// # Errors
///
/// Returns `Error::InvalidInput` for any other type byte and propagates any
/// cursor read error. The name is read before the type byte is checked, so a
/// read error on the name takes precedence.
fn decode_close(mut cursor: Cursor) -> Result<FrontendMessage> {
    let first_char = cursor.read_byte()?;
    let name = cursor.read_cstr()?.to_owned();
    match first_char {
        b'P' => Ok(FrontendMessage::ClosePortal { name }),
        b'S' => Ok(FrontendMessage::CloseStatement { name }),
        // `char::from(u8)` is infallible, so no `unwrap` is needed on
        // client-supplied input.
        other => Err(Error::InvalidInput(format!(
            "invalid type byte in Close frontend message: {:?}",
            char::from(other),
        ))),
    }
}
fn decode_describe(mut cursor: Cursor) -> Result<FrontendMessage> {
let first_char = cursor.read_byte()?;
let name = cursor.read_cstr()?.to_owned();
match first_char {
b'P' => Ok(FrontendMessage::DescribePortal { name }),
b'S' => Ok(FrontendMessage::DescribeStatement { name }),
other => Err(Error::InvalidInput(format!(
"invalid type byte in Describe frontend message: {:?}",
char::from(other),
))),
}
}
/// Decodes the body of an `Execute` message: a NUL-terminated portal name
/// followed by an `i32` row limit. Cursor read errors are propagated.
fn decode_execute(mut cursor: Cursor) -> Result<FrontendMessage> {
    let portal_name = cursor.read_cstr()?.to_owned();
    // Field initializers run in source order, so the limit is read after the name.
    Ok(FrontendMessage::Execute {
        portal_name,
        max_rows: cursor.read_i32()?,
    })
}
/// Decodes a `Flush` message. The message has no payload, so the cursor is
/// ignored and decoding always succeeds.
fn decode_flush(_cursor: Cursor) -> Result<FrontendMessage> {
Ok(FrontendMessage::Flush)
}
/// Decodes the body of a `Parse` message.
///
/// Reads the NUL-terminated statement name and SQL text, then an `i16` count
/// followed by that many `u32` type OIDs, each converted with
/// `PgType::from_oid`. Cursor read errors and OID conversion errors are
/// propagated.
fn decode_parse(mut cursor: Cursor) -> Result<FrontendMessage> {
    let statement_name = cursor.read_cstr()?.to_owned();
    let sql = cursor.read_cstr()?.to_owned();

    let declared = cursor.read_i16()?;
    let mut param_types = Vec::new();
    for _ in 0..declared {
        let raw_oid = cursor.read_u32()?;
        let pg_type = PgType::from_oid(raw_oid)?;
        log::trace!("OID {:?}", pg_type);
        param_types.push(pg_type);
    }

    Ok(FrontendMessage::Parse {
        statement_name,
        sql,
        param_types,
    })
}
/// Decodes a `Sync` message. The message has no payload, so the cursor is
/// ignored and decoding always succeeds.
fn decode_sync(_cursor: Cursor) -> Result<FrontendMessage> {
Ok(FrontendMessage::Sync)
}
/// Decodes the body of a `Query` message: a single NUL-terminated SQL string.
/// Cursor read errors are propagated.
fn decode_query(mut cursor: Cursor) -> Result<FrontendMessage> {
    let text = cursor.read_cstr()?;
    Ok(FrontendMessage::Query { sql: text.to_owned() })
}
/// Decodes a `Terminate` message. The message has no payload, so the cursor
/// is ignored and decoding always succeeds.
fn decode_terminate(_cursor: Cursor) -> Result<FrontendMessage> {
Ok(FrontendMessage::Terminate)
}
// Unit tests for `FrontendMessage::decode`: each test feeds a tag byte plus a
// hand-written message body and checks the decoded variant.
#[cfg(test)]
mod decoding_frontend_messages {
use super::*;
#[test]
fn query() {
// ASCII for "create schema schema_name;" followed by the NUL terminator.
let buffer = [
99, 114, 101, 97, 116, 101, 32, 115, 99, 104, 101, 109, 97, 32, 115, 99, 104, 101, 109, 97, 95, 110, 97,
109, 101, 59, 0,
];
let message = FrontendMessage::decode(b'Q', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::Query {
sql: "create schema schema_name;".to_owned()
})
);
}
#[test]
fn bind() {
// "portal_name\0", "statement_name\0", then:
//   0, 2             two parameter format codes: 0, 1 and 0, 1
//   0, 2             two parameters, each length 0, 0, 0, 4 plus four value bytes
//   0, 0             zero result format codes
let buffer = [
112, 111, 114, 116, 97, 108, 95, 110, 97, 109, 101, 0, 115, 116, 97, 116, 101, 109, 101, 110, 116, 95, 110,
97, 109, 101, 0, 0, 2, 0, 1, 0, 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0,
];
let message = FrontendMessage::decode(b'B', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::Bind {
portal_name: "portal_name".to_owned(),
statement_name: "statement_name".to_owned(),
param_formats: vec![PgFormat::Binary, PgFormat::Binary],
raw_params: vec![Some(vec![0, 0, 0, 1]), Some(vec![0, 0, 0, 2])],
result_formats: vec![],
})
);
}
#[test]
fn close_portal() {
// Type byte 80 (b'P') followed by "portal_name\0".
let buffer = [80, 112, 111, 114, 116, 97, 108, 95, 110, 97, 109, 101, 0];
let message = FrontendMessage::decode(b'C', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::ClosePortal {
name: "portal_name".to_owned(),
})
);
}
#[test]
fn close_statement() {
// Type byte 83 (b'S') followed by "statement_name\0".
let buffer = [83, 115, 116, 97, 116, 101, 109, 101, 110, 116, 95, 110, 97, 109, 101, 0];
let message = FrontendMessage::decode(b'C', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::CloseStatement {
name: "statement_name".to_owned(),
})
);
}
#[test]
fn describe_portal() {
// Type byte 80 (b'P') followed by "portal_name\0".
let buffer = [80, 112, 111, 114, 116, 97, 108, 95, 110, 97, 109, 101, 0];
let message = FrontendMessage::decode(b'D', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::DescribePortal {
name: "portal_name".to_owned()
})
);
}
#[test]
fn describe_statement() {
// Type byte 83 (b'S') followed by "statement_name\0".
let buffer = [83, 115, 116, 97, 116, 101, 109, 101, 110, 116, 95, 110, 97, 109, 101, 0];
let message = FrontendMessage::decode(b'D', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::DescribeStatement {
name: "statement_name".to_owned()
})
);
}
#[test]
fn execute() {
// "portal_name\0" followed by a four-byte row limit of 0.
let buffer = [112, 111, 114, 116, 97, 108, 95, 110, 97, 109, 101, 0, 0, 0, 0, 0];
let message = FrontendMessage::decode(b'E', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::Execute {
portal_name: "portal_name".to_owned(),
max_rows: 0,
})
);
}
#[test]
fn flush() {
// Flush has no payload.
let message = FrontendMessage::decode(b'H', &[]);
assert_eq!(message, Ok(FrontendMessage::Flush));
}
#[test]
fn parse() {
// Empty statement name (single NUL), the NUL-terminated SQL text, then
// 0, 1 (one parameter type) and 0, 0, 0, 23 (its OID).
let buffer = [
0, 115, 101, 108, 101, 99, 116, 32, 42, 32, 102, 114, 111, 109, 32, 115, 99, 104, 101, 109, 97, 95, 110,
97, 109, 101, 46, 116, 97, 98, 108, 101, 95, 110, 97, 109, 101, 32, 119, 104, 101, 114, 101, 32, 115, 105,
95, 99, 111, 108, 117, 109, 110, 32, 61, 32, 36, 49, 59, 0, 0, 1, 0, 0, 0, 23,
];
let message = FrontendMessage::decode(b'P', &buffer);
assert_eq!(
message,
Ok(FrontendMessage::Parse {
statement_name: "".to_owned(),
sql: "select * from schema_name.table_name where si_column = $1;".to_owned(),
param_types: vec![Some(PgType::Integer)],
})
);
}
#[test]
fn sync() {
// Sync has no payload.
let message = FrontendMessage::decode(b'S', &[]);
assert_eq!(message, Ok(FrontendMessage::Sync));
}
#[test]
fn terminate() {
// Terminate has no payload.
let message = FrontendMessage::decode(b'X', &[]);
assert_eq!(message, Ok(FrontendMessage::Terminate));
}
}
// Unit tests for `BackendMessage::as_vec`: each test compares the serialized
// bytes with a hand-written frame (tag, big-endian length, body).
#[cfg(test)]
mod serializing_backend_messages {
use super::*;
#[test]
fn notice() {
// Pins the current behaviour: the notice is the bare tag byte only.
assert_eq!(BackendMessage::NoticeResponse.as_vec(), vec![NOTICE_RESPONSE]);
}
#[test]
fn authentication_cleartext_password() {
assert_eq!(
BackendMessage::AuthenticationCleartextPassword.as_vec(),
vec![AUTHENTICATION, 0, 0, 0, 8, 0, 0, 0, 3]
)
}
#[test]
fn authentication_md5_password() {
assert_eq!(
BackendMessage::AuthenticationMD5Password.as_vec(),
vec![AUTHENTICATION, 0, 0, 0, 12, 0, 0, 0, 5, 1, 1, 1, 1]
)
}
#[test]
fn authentication_ok() {
assert_eq!(
BackendMessage::AuthenticationOk.as_vec(),
vec![AUTHENTICATION, 0, 0, 0, 8, 0, 0, 0, 0]
)
}
#[test]
fn backend_key_data() {
// Length 12, then connection id 1 and secret key 2 as four bytes each.
assert_eq!(
BackendMessage::BackendKeyData(1, 2).as_vec(),
vec![BACKEND_KEY_DATA, 0, 0, 0, 12, 0, 0, 0, 1, 0, 0, 0, 2]
)
}
#[test]
fn parameter_status() {
// Length 25 = 4 + "client_encoding\0" (16 bytes) + "UTF8\0" (5 bytes).
assert_eq!(
BackendMessage::ParameterStatus("client_encoding".to_owned(), "UTF8".to_owned()).as_vec(),
vec![
PARAMETER_STATUS,
0,
0,
0,
25,
99,
108,
105,
101,
110,
116,
95,
101,
110,
99,
111,
100,
105,
110,
103,
0,
85,
84,
70,
56,
0
]
)
}
#[test]
fn ready_for_query() {
// The status byte is b'I', shared with `EMPTY_QUERY_RESPONSE`.
assert_eq!(
BackendMessage::ReadyForQuery.as_vec(),
vec![READY_FOR_QUERY, 0, 0, 0, 5, EMPTY_QUERY_RESPONSE]
)
}
#[test]
fn data_row() {
// Length 21 = 4 + 2 (field count) + 3 * (4-byte length + 1 ASCII digit).
assert_eq!(
BackendMessage::DataRow(vec!["1".to_owned(), "2".to_owned(), "3".to_owned()]).as_vec(),
vec![DATA_ROW, 0, 0, 0, 21, 0, 3, 0, 0, 0, 1, 49, 0, 0, 0, 1, 50, 0, 0, 0, 1, 51]
)
}
#[test]
fn row_description() {
// Length 27 = 4 + 2 (column count) + "c1\0" (3 bytes) + 18 bytes of
// fixed-width column attributes (type OID 23 and type size 4 among them).
assert_eq!(
BackendMessage::RowDescription(vec![ColumnMetadata::new("c1".to_owned(), PgType::Integer)]).as_vec(),
vec![
ROW_DESCRIPTION,
0,
0,
0,
27,
0,
1,
99,
49,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
23,
0,
4,
255,
255,
255,
255,
0,
0
]
);
}
#[test]
fn command_complete() {
// Length 11 = 4 + "SELECT" (6 bytes) + NUL terminator.
assert_eq!(
BackendMessage::CommandComplete("SELECT".to_owned()).as_vec(),
vec![COMMAND_COMPLETE, 0, 0, 0, 11, 83, 69, 76, 69, 67, 84, 0]
)
}
#[test]
fn empty_response() {
assert_eq!(
BackendMessage::EmptyQueryResponse.as_vec(),
vec![EMPTY_QUERY_RESPONSE, 0, 0, 0, 4]
)
}
#[test]
fn error_response() {
// With no fields the body is just the final NUL byte (length 5).
assert_eq!(
BackendMessage::ErrorResponse(None, None, None).as_vec(),
vec![ERROR_RESPONSE, 0, 0, 0, 5, 0]
)
}
#[test]
fn parameter_description() {
// Length 10 = 4 + 2 (parameter count) + one 4-byte OID (23).
assert_eq!(
BackendMessage::ParameterDescription(vec![PgType::Integer]).as_vec(),
vec![PARAMETER_DESCRIPTION, 0, 0, 0, 10, 0, 1, 0, 0, 0, 23]
)
}
#[test]
fn no_data() {
assert_eq!(BackendMessage::NoData.as_vec(), vec![NO_DATA, 0, 0, 0, 4])
}
#[test]
fn parse_complete() {
assert_eq!(BackendMessage::ParseComplete.as_vec(), vec![PARSE_COMPLETE, 0, 0, 0, 4])
}
#[test]
fn bind_complete() {
assert_eq!(BackendMessage::BindComplete.as_vec(), vec![BIND_COMPLETE, 0, 0, 0, 4])
}
#[test]
fn close_complete() {
assert_eq!(BackendMessage::CloseComplete.as_vec(), vec![CLOSE_COMPLETE, 0, 0, 0, 4])
}
}