use super::EncodeError;
use bytes::BytesMut;
/// Encoder for PostgreSQL frontend (client-to-server) protocol messages.
///
/// Stateless: every encoder is an associated function, and this unit struct only
/// serves as their namespace.
pub struct PgEncoder;
impl PgEncoder {
    /// Format code selecting the text representation.
    pub const FORMAT_TEXT: i16 = 0;
    /// Format code selecting the binary representation.
    pub const FORMAT_BINARY: i16 = 1;

    /// Size in bytes of the result-format section that
    /// `encode_result_formats_bytesmut` writes for `result_format`.
    #[inline]
    fn result_format_wire_len(result_format: i16) -> usize {
        if result_format == Self::FORMAT_TEXT {
            // Only the Int16 count (zero codes).
            2
        } else {
            // Int16 count (one code) plus the Int16 code itself.
            4
        }
    }

    /// Appends the result-format section of a `Bind` message.
    ///
    /// Text is sent as zero format codes (the protocol default); any other value is
    /// sent as a single code applied to every result column. The value itself is
    /// not validated.
    #[inline]
    fn encode_result_formats_bytesmut(buf: &mut BytesMut, result_format: i16) {
        if result_format == Self::FORMAT_TEXT {
            buf.extend_from_slice(&0i16.to_be_bytes());
        } else {
            buf.extend_from_slice(&1i16.to_be_bytes());
            buf.extend_from_slice(&result_format.to_be_bytes());
        }
    }

    /// Converts a message body length into the Int32 length word, which counts
    /// itself (4 bytes) but not the leading tag byte.
    ///
    /// # Errors
    /// [`EncodeError::MessageTooLarge`] if the total does not fit in an `i32`.
    #[inline]
    fn content_len_to_wire_len(content_len: usize) -> Result<i32, EncodeError> {
        let total = content_len
            .checked_add(4)
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;
        i32::try_from(total).map_err(|_| EncodeError::MessageTooLarge(total))
    }

    /// Adds up length components, reporting `usize` overflow as
    /// `MessageTooLarge(usize::MAX)`.
    #[inline]
    fn checked_len_sum(parts: &[usize]) -> Result<usize, EncodeError> {
        parts.iter().try_fold(0usize, |acc, &part| {
            acc.checked_add(part)
                .ok_or(EncodeError::MessageTooLarge(usize::MAX))
        })
    }

    /// Narrows a count to the protocol's Int16 count fields.
    #[inline]
    fn usize_to_i16(n: usize) -> Result<i16, EncodeError> {
        i16::try_from(n).map_err(|_| EncodeError::TooManyParameters(n))
    }

    /// Narrows a length to the protocol's Int32 length fields.
    #[inline]
    fn usize_to_i32(n: usize) -> Result<i32, EncodeError> {
        i32::try_from(n).map_err(|_| EncodeError::MessageTooLarge(n))
    }

    /// Returns `true` if `s` contains a NUL byte, which cannot appear inside a
    /// NUL-terminated protocol string.
    #[inline]
    fn has_nul(s: &str) -> bool {
        s.as_bytes().contains(&0)
    }

    /// Appends the 5-byte message header: tag byte, then the big-endian length word.
    #[inline]
    fn write_message_header(buf: &mut BytesMut, tag: u8, wire_len: i32) {
        buf.extend_from_slice(&[tag]);
        buf.extend_from_slice(&wire_len.to_be_bytes());
    }

    /// Appends `s` followed by its NUL terminator. Callers reject interior NULs first.
    #[inline]
    fn write_cstr(buf: &mut BytesMut, s: &str) {
        buf.extend_from_slice(s.as_bytes());
        buf.extend_from_slice(&[0]);
    }

    /// Allocates a buffer sized for one message whose body is `content_len` bytes
    /// and writes its header, so the body can be appended without regrowth.
    ///
    /// # Errors
    /// [`EncodeError::MessageTooLarge`] if `content_len + 4` does not fit in an `i32`.
    fn new_message(tag: u8, content_len: usize) -> Result<BytesMut, EncodeError> {
        let wire_len = Self::content_len_to_wire_len(content_len)?;
        // `content_len + 4` fits in an i32, so adding the tag byte cannot overflow.
        let mut buf = BytesMut::with_capacity(1 + 4 + content_len);
        Self::write_message_header(&mut buf, tag, wire_len);
        Ok(buf)
    }

    /// Target-kind byte shared by `Describe` and `Close`: `P`ortal or `S`tatement.
    #[inline]
    fn portal_or_statement_tag(is_portal: bool) -> u8 {
        if is_portal {
            b'P'
        } else {
            b'S'
        }
    }

    /// Appends `Bind` parameter values: an Int32 length followed by the bytes, or
    /// length `-1` with no bytes for `None` (SQL NULL).
    fn write_bind_values(
        buf: &mut BytesMut,
        params: &[Option<Vec<u8>>],
    ) -> Result<(), EncodeError> {
        for param in params {
            match param {
                None => buf.extend_from_slice(&(-1i32).to_be_bytes()),
                Some(data) => {
                    let data_len = Self::usize_to_i32(data.len())?;
                    buf.extend_from_slice(&data_len.to_be_bytes());
                    buf.extend_from_slice(data);
                }
            }
        }
        Ok(())
    }

    /// Encodes a simple-protocol `Query` ('Q') message carrying `sql`.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `sql` contains a NUL byte.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    pub fn try_encode_query_string(sql: &str) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(sql) {
            return Err(EncodeError::NullByte);
        }
        let content_len = Self::checked_len_sum(&[sql.len(), 1])?;
        let mut buf = Self::new_message(b'Q', content_len)?;
        Self::write_cstr(&mut buf, sql);
        Ok(buf)
    }

    /// Encodes a `Terminate` ('X') message, which has no body.
    pub fn encode_terminate() -> BytesMut {
        let mut buf = BytesMut::with_capacity(5);
        buf.extend_from_slice(&[b'X', 0, 0, 0, 4]);
        buf
    }

    /// Encodes a `Sync` ('S') message, which has no body.
    pub fn encode_sync() -> BytesMut {
        let mut buf = BytesMut::with_capacity(5);
        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
        buf
    }

    /// Encodes a `Parse` ('P') message preparing `sql` under statement `name`
    /// (empty for the unnamed statement), with one Int32 type OID per entry of
    /// `param_types`.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `name` or `sql` contains a NUL byte.
    /// - [`EncodeError::TooManyParameters`] if `param_types` has more than `i16::MAX` entries.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    pub fn try_encode_parse(
        name: &str,
        sql: &str,
        param_types: &[u32],
    ) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(name) || Self::has_nul(sql) {
            return Err(EncodeError::NullByte);
        }
        let param_count = Self::usize_to_i16(param_types.len())?;
        // Bounded by i16::MAX just above, so this cannot overflow.
        let oids_len = 4 * param_types.len();
        let content_len = Self::checked_len_sum(&[name.len(), 1, sql.len(), 1, 2, oids_len])?;
        let mut buf = Self::new_message(b'P', content_len)?;
        Self::write_cstr(&mut buf, name);
        Self::write_cstr(&mut buf, sql);
        buf.extend_from_slice(&param_count.to_be_bytes());
        for oid in param_types {
            buf.extend_from_slice(&oid.to_be_bytes());
        }
        Ok(buf)
    }

    /// Encodes a `Bind` ('B') message requesting text-format results.
    ///
    /// Equivalent to [`Self::encode_bind_with_result_format`] with [`Self::FORMAT_TEXT`].
    pub fn encode_bind(
        portal: &str,
        statement: &str,
        params: &[Option<Vec<u8>>],
    ) -> Result<BytesMut, EncodeError> {
        Self::encode_bind_with_result_format(portal, statement, params, Self::FORMAT_TEXT)
    }

    /// Encodes a `Bind` ('B') message binding `statement` to `portal`.
    ///
    /// Parameters are sent in text format (no parameter format codes); `None` is
    /// sent as SQL NULL. `result_format` is applied to every result column.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `portal` or `statement` contains a NUL byte.
    /// - [`EncodeError::TooManyParameters`] if there are more than `i16::MAX` parameters.
    /// - [`EncodeError::MessageTooLarge`] if a value or the whole message exceeds the
    ///   protocol's Int32 lengths.
    pub fn encode_bind_with_result_format(
        portal: &str,
        statement: &str,
        params: &[Option<Vec<u8>>],
        result_format: i16,
    ) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(portal) || Self::has_nul(statement) {
            return Err(EncodeError::NullByte);
        }
        let param_count = Self::usize_to_i16(params.len())?;
        // Check each value's Int32 length prefix in order before the total, so an
        // oversized value is reported as `MessageTooLarge(value_len)`.
        let params_size = params.iter().try_fold(0usize, |acc, param| {
            let value_len = match param {
                None => 0,
                Some(data) => {
                    Self::usize_to_i32(data.len())?;
                    data.len()
                }
            };
            Self::checked_len_sum(&[acc, 4, value_len])
        })?;
        let content_len = Self::checked_len_sum(&[
            portal.len(),
            1,
            statement.len(),
            1,
            2, // parameter format code count
            2, // parameter count
            params_size,
            Self::result_format_wire_len(result_format),
        ])?;
        let mut buf = Self::new_message(b'B', content_len)?;
        Self::write_cstr(&mut buf, portal);
        Self::write_cstr(&mut buf, statement);
        // Zero parameter format codes: every parameter is sent as text.
        buf.extend_from_slice(&0i16.to_be_bytes());
        buf.extend_from_slice(&param_count.to_be_bytes());
        Self::write_bind_values(&mut buf, params)?;
        Self::encode_result_formats_bytesmut(&mut buf, result_format);
        Ok(buf)
    }

    /// Encodes an `Execute` ('E') message for `portal`; `max_rows == 0` means no limit.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `portal` contains a NUL byte.
    /// - [`EncodeError::InvalidMaxRows`] if `max_rows` is negative.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    pub fn try_encode_execute(portal: &str, max_rows: i32) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(portal) {
            return Err(EncodeError::NullByte);
        }
        if max_rows < 0 {
            return Err(EncodeError::InvalidMaxRows(max_rows));
        }
        let content_len = Self::checked_len_sum(&[portal.len(), 1, 4])?;
        let mut buf = Self::new_message(b'E', content_len)?;
        Self::write_cstr(&mut buf, portal);
        buf.extend_from_slice(&max_rows.to_be_bytes());
        Ok(buf)
    }

    /// Encodes a `Describe` ('D') message for the portal (`is_portal`) or prepared
    /// statement called `name`.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `name` contains a NUL byte.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    pub fn try_encode_describe(is_portal: bool, name: &str) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(name) {
            return Err(EncodeError::NullByte);
        }
        let content_len = Self::checked_len_sum(&[1, name.len(), 1])?;
        let mut buf = Self::new_message(b'D', content_len)?;
        buf.extend_from_slice(&[Self::portal_or_statement_tag(is_portal)]);
        Self::write_cstr(&mut buf, name);
        Ok(buf)
    }

    /// Encodes Parse/Bind/Execute/Sync for `sql` with text-format results.
    ///
    /// Equivalent to [`Self::encode_extended_query_with_result_format`] with
    /// [`Self::FORMAT_TEXT`].
    pub fn encode_extended_query(
        sql: &str,
        params: &[Option<Vec<u8>>],
    ) -> Result<BytesMut, EncodeError> {
        Self::encode_extended_query_with_result_format(sql, params, Self::FORMAT_TEXT)
    }

    /// Encodes a complete extended-protocol round trip into one buffer: `Parse`
    /// into the unnamed statement (no declared parameter types), `Bind` into the
    /// unnamed portal (text parameters), `Execute` with no row limit, then `Sync`.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `sql` contains a NUL byte.
    /// - [`EncodeError::TooManyParameters`] if there are more than `i16::MAX` parameters.
    /// - [`EncodeError::MessageTooLarge`] if a message length does not fit in an `i32`.
    pub fn encode_extended_query_with_result_format(
        sql: &str,
        params: &[Option<Vec<u8>>],
        result_format: i16,
    ) -> Result<BytesMut, EncodeError> {
        // Execute on the unnamed portal with max_rows = 0, and a bodiless Sync.
        const EXECUTE_UNNAMED_ALL_ROWS: [u8; 10] = [b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0];
        const SYNC: [u8; 5] = [b'S', 0, 0, 0, 4];

        if Self::has_nul(sql) {
            return Err(EncodeError::NullByte);
        }
        let param_count = Self::usize_to_i16(params.len())?;
        let params_size = params.iter().try_fold(0usize, |acc, param| {
            let value_len = param.as_ref().map_or(0, |data| data.len());
            Self::checked_len_sum(&[acc, 4, value_len])
        })?;
        let result_formats_size = Self::result_format_wire_len(result_format);

        // Parse body: empty statement name, SQL text, zero parameter type OIDs.
        let parse_content_len = Self::checked_len_sum(&[1, sql.len(), 1, 2])?;
        // Bind body: empty portal and statement names, format-code count,
        // parameter count, values, result formats.
        let bind_content_len =
            Self::checked_len_sum(&[1, 1, 2, 2, params_size, result_formats_size])?;
        let total_size = Self::checked_len_sum(&[
            1 + 4,
            parse_content_len,
            1 + 4,
            bind_content_len,
            EXECUTE_UNNAMED_ALL_ROWS.len(),
            SYNC.len(),
        ])?;
        let parse_len = Self::content_len_to_wire_len(parse_content_len)?;
        let bind_len = Self::content_len_to_wire_len(bind_content_len)?;

        let mut buf = BytesMut::with_capacity(total_size);

        Self::write_message_header(&mut buf, b'P', parse_len);
        Self::write_cstr(&mut buf, "");
        Self::write_cstr(&mut buf, sql);
        buf.extend_from_slice(&0i16.to_be_bytes());

        Self::write_message_header(&mut buf, b'B', bind_len);
        Self::write_cstr(&mut buf, ""); // unnamed portal
        Self::write_cstr(&mut buf, ""); // unnamed statement
        buf.extend_from_slice(&0i16.to_be_bytes()); // all parameters as text
        buf.extend_from_slice(&param_count.to_be_bytes());
        Self::write_bind_values(&mut buf, params)?;
        Self::encode_result_formats_bytesmut(&mut buf, result_format);

        buf.extend_from_slice(&EXECUTE_UNNAMED_ALL_ROWS);
        buf.extend_from_slice(&SYNC);
        Ok(buf)
    }

    /// Encodes a `CopyFail` ('f') message carrying `reason`.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `reason` contains a NUL byte.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    pub fn try_encode_copy_fail(reason: &str) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(reason) {
            return Err(EncodeError::NullByte);
        }
        let content_len = Self::checked_len_sum(&[reason.len(), 1])?;
        let mut buf = Self::new_message(b'f', content_len)?;
        Self::write_cstr(&mut buf, reason);
        Ok(buf)
    }

    /// Encodes a `Close` ('C') message for the portal (`is_portal`) or prepared
    /// statement called `name`.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `name` contains a NUL byte.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    pub fn try_encode_close(is_portal: bool, name: &str) -> Result<BytesMut, EncodeError> {
        if Self::has_nul(name) {
            return Err(EncodeError::NullByte);
        }
        let content_len = Self::checked_len_sum(&[1, name.len(), 1])?;
        let mut buf = Self::new_message(b'C', content_len)?;
        buf.extend_from_slice(&[Self::portal_or_statement_tag(is_portal)]);
        Self::write_cstr(&mut buf, name);
        Ok(buf)
    }
}
use bytes::BufMut;
/// A single `Bind` parameter value borrowed from the caller.
///
/// `Null` is written as length `-1` with no value bytes; `Bytes` is written as its
/// Int32 length followed by the bytes verbatim.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Param<'a> {
    /// SQL `NULL`.
    Null,
    /// The parameter's encoded value bytes.
    Bytes(&'a [u8]),
}
impl PgEncoder {
    /// Appends `v` as a big-endian Int32.
    #[inline]
    fn put_i32_be(buf: &mut BytesMut, v: i32) {
        buf.put_i32(v);
    }

    /// Appends `v` as a big-endian Int16.
    #[inline]
    fn put_i16_be(buf: &mut BytesMut, v: i16) {
        buf.put_i16(v);
    }

    /// Shared body of the appending `Bind` encoders: writes a `Bind` binding
    /// `statement` to the unnamed portal, with every parameter in text format and
    /// `None` items sent as SQL NULL.
    ///
    /// All validation happens before the first byte is appended, so on error `buf`
    /// is left unchanged.
    fn write_unnamed_portal_bind<'p, I>(
        buf: &mut BytesMut,
        statement: &str,
        params: I,
        result_format: i16,
    ) -> Result<(), EncodeError>
    where
        I: ExactSizeIterator<Item = Option<&'p [u8]>> + Clone,
    {
        if Self::has_nul(statement) {
            return Err(EncodeError::NullByte);
        }
        let param_count = Self::usize_to_i16(params.len())?;
        let params_size = params.clone().try_fold(0usize, |acc, value| {
            let field_size = 4usize
                .checked_add(value.map_or(0, <[u8]>::len))
                .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;
            acc.checked_add(field_size)
                .ok_or(EncodeError::MessageTooLarge(usize::MAX))
        })?;
        let content_len = 1usize // unnamed portal terminator
            .checked_add(statement.len())
            .and_then(|v| v.checked_add(1)) // statement terminator
            .and_then(|v| v.checked_add(2)) // parameter format code count
            .and_then(|v| v.checked_add(2)) // parameter count
            .and_then(|v| v.checked_add(params_size))
            .and_then(|v| v.checked_add(Self::result_format_wire_len(result_format)))
            .ok_or(EncodeError::MessageTooLarge(usize::MAX))?;
        let wire_len = Self::content_len_to_wire_len(content_len)?;

        // `content_len + 4` fits in an i32, so this sum cannot overflow.
        buf.reserve(1 + 4 + content_len);
        buf.put_u8(b'B');
        Self::put_i32_be(buf, wire_len);
        buf.put_u8(0);
        buf.extend_from_slice(statement.as_bytes());
        buf.put_u8(0);
        // Zero parameter format codes: every parameter is sent as text.
        Self::put_i16_be(buf, 0);
        Self::put_i16_be(buf, param_count);
        for value in params {
            match value {
                None => Self::put_i32_be(buf, -1),
                Some(data) => {
                    // Cannot fail: the whole message already fits in an i32.
                    Self::put_i32_be(buf, Self::usize_to_i32(data.len())?);
                    buf.extend_from_slice(data);
                }
            }
        }
        Self::encode_result_formats_bytesmut(buf, result_format);
        Ok(())
    }

    /// Appends a `Bind` ('B') message for the unnamed portal with text-format results.
    ///
    /// Equivalent to [`Self::encode_bind_ultra_with_result_format`] with
    /// [`Self::FORMAT_TEXT`].
    #[inline]
    pub fn encode_bind_ultra<'a>(
        buf: &mut BytesMut,
        statement: &str,
        params: &[Param<'a>],
    ) -> Result<(), EncodeError> {
        Self::encode_bind_ultra_with_result_format(buf, statement, params, Self::FORMAT_TEXT)
    }

    /// Appends a `Bind` ('B') message binding `statement` to the unnamed portal,
    /// taking borrowed [`Param`] values. On error `buf` is left unchanged.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `statement` contains a NUL byte.
    /// - [`EncodeError::TooManyParameters`] if there are more than `i16::MAX` parameters.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    #[inline]
    pub fn encode_bind_ultra_with_result_format<'a>(
        buf: &mut BytesMut,
        statement: &str,
        params: &[Param<'a>],
        result_format: i16,
    ) -> Result<(), EncodeError> {
        let values = params.iter().map(|param| match param {
            Param::Null => None,
            Param::Bytes(data) => Some(*data),
        });
        Self::write_unnamed_portal_bind(buf, statement, values, result_format)
    }

    /// Appends an `Execute` ('E') message for the unnamed portal with no row limit.
    #[inline]
    pub fn encode_execute_ultra(buf: &mut BytesMut) {
        // Length 9 = length word + empty portal name terminator + Int32 max_rows (0).
        buf.extend_from_slice(&[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
    }

    /// Appends a bodiless `Sync` ('S') message.
    #[inline]
    pub fn encode_sync_ultra(buf: &mut BytesMut) {
        buf.extend_from_slice(&[b'S', 0, 0, 0, 4]);
    }

    /// Appends a `Bind` ('B') message for the unnamed portal with text-format results.
    ///
    /// Equivalent to [`Self::encode_bind_to_with_result_format`] with
    /// [`Self::FORMAT_TEXT`].
    #[inline]
    pub fn encode_bind_to(
        buf: &mut BytesMut,
        statement: &str,
        params: &[Option<Vec<u8>>],
    ) -> Result<(), EncodeError> {
        Self::encode_bind_to_with_result_format(buf, statement, params, Self::FORMAT_TEXT)
    }

    /// Appends a `Bind` ('B') message binding `statement` to the unnamed portal,
    /// taking owned values (`None` is SQL NULL). On error `buf` is left unchanged.
    ///
    /// # Errors
    /// - [`EncodeError::NullByte`] if `statement` contains a NUL byte.
    /// - [`EncodeError::TooManyParameters`] if there are more than `i16::MAX` parameters.
    /// - [`EncodeError::MessageTooLarge`] if the message length does not fit in an `i32`.
    #[inline]
    pub fn encode_bind_to_with_result_format(
        buf: &mut BytesMut,
        statement: &str,
        params: &[Option<Vec<u8>>],
        result_format: i16,
    ) -> Result<(), EncodeError> {
        let values = params.iter().map(|param| param.as_deref());
        Self::write_unnamed_portal_bind(buf, statement, values, result_format)
    }

    /// Appends an `Execute` ('E') message for the unnamed portal with no row limit.
    ///
    /// Same bytes as [`Self::encode_execute_ultra`].
    #[inline]
    pub fn encode_execute_to(buf: &mut BytesMut) {
        Self::encode_execute_ultra(buf);
    }

    /// Appends a bodiless `Sync` ('S') message.
    ///
    /// Same bytes as [`Self::encode_sync_ultra`].
    #[inline]
    pub fn encode_sync_to(buf: &mut BytesMut) {
        Self::encode_sync_ultra(buf);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the big-endian Int32 starting at `at`.
    fn be_i32(bytes: &[u8], at: usize) -> i32 {
        i32::from_be_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
    }

    /// Splits back-to-back protocol messages into `(tag, body)` pairs, panicking if
    /// any length word disagrees with the bytes actually present.
    fn split_frames(mut bytes: &[u8]) -> Vec<(u8, &[u8])> {
        let mut frames = Vec::new();
        while !bytes.is_empty() {
            assert!(bytes.len() >= 5, "truncated message header");
            let len = usize::try_from(be_i32(bytes, 1)).expect("negative message length");
            assert!(len >= 4, "length word smaller than itself");
            assert!(bytes.len() > len, "truncated message body");
            frames.push((bytes[0], &bytes[5..1 + len]));
            bytes = &bytes[1 + len..];
        }
        frames
    }

    #[test]
    fn test_encode_query_string() {
        let bytes = PgEncoder::try_encode_query_string("SELECT 1").unwrap();
        assert_eq!(bytes[0], b'Q');
        assert_eq!(be_i32(&bytes, 1), 13);
        assert_eq!(&bytes[5..13], b"SELECT 1");
        assert_eq!(bytes[13], 0);
        assert_eq!(bytes.len(), 14);
    }

    #[test]
    fn test_encode_terminate() {
        let bytes = PgEncoder::encode_terminate();
        assert_eq!(&bytes[..], &[b'X', 0, 0, 0, 4]);
    }

    #[test]
    fn test_encode_sync() {
        let bytes = PgEncoder::encode_sync();
        assert_eq!(&bytes[..], &[b'S', 0, 0, 0, 4]);
    }

    #[test]
    fn test_sync_encoders_agree() {
        let owned = PgEncoder::encode_sync();
        let mut ultra = BytesMut::new();
        PgEncoder::encode_sync_ultra(&mut ultra);
        let mut to = BytesMut::new();
        PgEncoder::encode_sync_to(&mut to);
        assert_eq!(owned, ultra);
        assert_eq!(owned, to);
    }

    #[test]
    fn test_encode_parse() {
        let bytes = PgEncoder::try_encode_parse("", "SELECT $1", &[]).unwrap();
        let mut expected = vec![b'P', 0, 0, 0, 17, 0];
        expected.extend_from_slice(b"SELECT $1");
        // SQL terminator, then Int16 parameter-type count 0.
        expected.extend_from_slice(&[0, 0, 0]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_parse_with_param_types() {
        let bytes = PgEncoder::try_encode_parse("s1", "SELECT $1", &[23]).unwrap();
        let mut expected = vec![b'P', 0, 0, 0, 23];
        expected.extend_from_slice(b"s1\0SELECT $1\0");
        expected.extend_from_slice(&[0, 1, 0, 0, 0, 23]);
        assert_eq!(&bytes[..], &expected[..]);
    }

    #[test]
    fn test_encode_bind() {
        let params = vec![Some(b"42".to_vec()), None];
        let bytes = PgEncoder::encode_bind("", "", &params).unwrap();
        let expected: &[u8] = &[
            b'B', 0, 0, 0, 22, // tag + length
            0, 0, // unnamed portal, unnamed statement
            0, 0, // no parameter format codes
            0, 2, // two parameters
            0, 0, 0, 2, b'4', b'2', // "42"
            0xFF, 0xFF, 0xFF, 0xFF, // NULL
            0, 0, // no result format codes (text)
        ];
        assert_eq!(&bytes[..], expected);
    }

    #[test]
    fn test_encode_bind_binary_result_format() {
        let bytes =
            PgEncoder::encode_bind_with_result_format("", "", &[], PgEncoder::FORMAT_BINARY)
                .unwrap();
        assert_eq!(&bytes[11..15], &[0, 1, 0, 1]);
    }

    #[test]
    fn test_encode_execute() {
        let bytes = PgEncoder::try_encode_execute("", 0).unwrap();
        assert_eq!(&bytes[..], &[b'E', 0, 0, 0, 9, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_execute_encoders_agree() {
        let owned = PgEncoder::try_encode_execute("", 0).unwrap();
        let mut ultra = BytesMut::new();
        PgEncoder::encode_execute_ultra(&mut ultra);
        let mut to = BytesMut::new();
        PgEncoder::encode_execute_to(&mut to);
        assert_eq!(owned, ultra);
        assert_eq!(owned, to);
    }

    #[test]
    fn test_encode_execute_negative_max_rows_returns_error() {
        let err = PgEncoder::try_encode_execute("", -1).expect_err("must reject negative max_rows");
        assert_eq!(err, EncodeError::InvalidMaxRows(-1));
    }

    #[test]
    fn test_encode_describe_statement() {
        let bytes = PgEncoder::try_encode_describe(false, "s").unwrap();
        assert_eq!(&bytes[..], &[b'D', 0, 0, 0, 7, b'S', b's', 0]);
    }

    #[test]
    fn test_encode_describe_portal() {
        let bytes = PgEncoder::try_encode_describe(true, "").unwrap();
        assert_eq!(&bytes[..], &[b'D', 0, 0, 0, 6, b'P', 0]);
    }

    #[test]
    fn test_encode_describe_with_nul_rejected() {
        let err = PgEncoder::try_encode_describe(true, "p\0").expect_err("must reject NUL");
        assert_eq!(err, EncodeError::NullByte);
    }

    #[test]
    fn test_encode_extended_query() {
        let params = vec![Some(b"hello".to_vec())];
        let bytes = PgEncoder::encode_extended_query("SELECT $1", &params).unwrap();
        let frames = split_frames(&bytes);
        let tags: Vec<u8> = frames.iter().map(|&(tag, _)| tag).collect();
        assert_eq!(tags, b"PBES".to_vec());
        assert_eq!(frames[0].1, b"\0SELECT $1\0\0\0");
        assert_eq!(
            frames[1].1,
            &[0, 0, 0, 0, 0, 1, 0, 0, 0, 5, b'h', b'e', b'l', b'l', b'o', 0, 0]
        );
        assert_eq!(frames[2].1, &[0, 0, 0, 0, 0]);
        assert!(frames[3].1.is_empty());
    }

    #[test]
    fn test_encode_extended_query_binary_result_format() {
        let bytes = PgEncoder::encode_extended_query_with_result_format(
            "SELECT 1",
            &[],
            PgEncoder::FORMAT_BINARY,
        )
        .unwrap();
        let frames = split_frames(&bytes);
        let (tag, bind_body) = frames[1];
        assert_eq!(tag, b'B');
        // Length word 14 = 4 (itself) + 10 body bytes.
        assert_eq!(bind_body.len(), 10);
        assert_eq!(&bind_body[6..10], &[0, 1, 0, 1]);
    }

    #[test]
    fn test_extended_query_matches_individual_messages() {
        let params = vec![Some(b"x".to_vec()), None];
        let sql = "SELECT $1, $2";
        let combined = PgEncoder::encode_extended_query_with_result_format(
            sql,
            &params,
            PgEncoder::FORMAT_BINARY,
        )
        .unwrap();
        let mut expected = BytesMut::new();
        expected.extend_from_slice(&PgEncoder::try_encode_parse("", sql, &[]).unwrap());
        expected.extend_from_slice(
            &PgEncoder::encode_bind_with_result_format("", "", &params, PgEncoder::FORMAT_BINARY)
                .unwrap(),
        );
        expected.extend_from_slice(&PgEncoder::try_encode_execute("", 0).unwrap());
        expected.extend_from_slice(&PgEncoder::encode_sync());
        assert_eq!(combined, expected);
    }

    #[test]
    fn test_encode_copy_fail() {
        let bytes = PgEncoder::try_encode_copy_fail("bad data").unwrap();
        assert_eq!(bytes[0], b'f');
        assert_eq!(be_i32(&bytes, 1) as usize, 4 + "bad data".len() + 1);
        assert_eq!(&bytes[5..13], b"bad data");
        assert_eq!(bytes[13], 0);
    }

    #[test]
    fn test_encode_close_statement() {
        let bytes = PgEncoder::try_encode_close(false, "my_stmt").unwrap();
        assert_eq!(bytes[0], b'C');
        assert_eq!(be_i32(&bytes, 1), 13);
        assert_eq!(bytes[5], b'S');
        assert_eq!(&bytes[6..13], b"my_stmt");
        assert_eq!(bytes[13], 0);
    }

    #[test]
    fn test_encode_close_portal() {
        let bytes = PgEncoder::try_encode_close(true, "").unwrap();
        assert_eq!(&bytes[..], &[b'C', 0, 0, 0, 6, b'P', 0]);
    }

    #[test]
    fn test_encode_parse_too_many_param_types_returns_error() {
        let param_types = vec![0u32; (i16::MAX as usize) + 1];
        let err =
            PgEncoder::try_encode_parse("s", "SELECT 1", &param_types).expect_err("must reject");
        assert_eq!(err, EncodeError::TooManyParameters(param_types.len()));
    }

    #[test]
    fn test_encode_bind_to_binary_result_format() {
        let mut buf = BytesMut::new();
        PgEncoder::encode_bind_to_with_result_format(&mut buf, "", &[], PgEncoder::FORMAT_BINARY)
            .unwrap();
        assert_eq!(&buf[11..15], &[0, 1, 0, 1]);
    }

    #[test]
    fn test_encode_bind_ultra_binary_result_format() {
        let mut buf = BytesMut::new();
        PgEncoder::encode_bind_ultra_with_result_format(
            &mut buf,
            "",
            &[],
            PgEncoder::FORMAT_BINARY,
        )
        .unwrap();
        assert_eq!(&buf[11..15], &[0, 1, 0, 1]);
    }

    #[test]
    fn test_bind_to_and_bind_ultra_agree() {
        let owned = vec![Some(b"abc".to_vec()), None, Some(Vec::new())];
        let borrowed = [Param::Bytes(b"abc"), Param::Null, Param::Bytes(b"")];
        let mut to_buf = BytesMut::new();
        PgEncoder::encode_bind_to(&mut to_buf, "stmt", &owned).unwrap();
        let mut ultra_buf = BytesMut::new();
        PgEncoder::encode_bind_ultra(&mut ultra_buf, "stmt", &borrowed).unwrap();
        assert_eq!(to_buf, ultra_buf);
        let standalone = PgEncoder::encode_bind("", "stmt", &owned).unwrap();
        assert_eq!(to_buf, standalone);
    }

    #[test]
    fn test_encode_bind_ultra_too_many_params_leaves_buffer_untouched() {
        let params: Vec<Param<'_>> = (0..=i16::MAX as usize).map(|_| Param::Null).collect();
        let mut buf = BytesMut::new();
        let err = PgEncoder::encode_bind_ultra(&mut buf, "", &params).expect_err("must reject");
        assert_eq!(err, EncodeError::TooManyParameters(params.len()));
        assert!(buf.is_empty());
    }

    #[test]
    fn test_encode_bind_to_with_nul_statement_leaves_buffer_untouched() {
        let mut buf = BytesMut::new();
        let err = PgEncoder::encode_bind_to(&mut buf, "s\0", &[]).expect_err("must reject NUL");
        assert_eq!(err, EncodeError::NullByte);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_encode_query_string_with_nul_returns_error() {
        let err =
            PgEncoder::try_encode_query_string("select 1\0select 2").expect_err("must reject NUL");
        assert_eq!(err, EncodeError::NullByte);
    }

    #[test]
    fn test_encode_parse_with_nul_returns_error() {
        let err = PgEncoder::try_encode_parse("s", "SELECT 1\0", &[]).expect_err("must reject");
        assert_eq!(err, EncodeError::NullByte);
    }

    #[test]
    fn test_encode_bind_with_nul_rejected() {
        let err = PgEncoder::encode_bind_with_result_format("\0", "", &[], PgEncoder::FORMAT_TEXT)
            .expect_err("bind with NUL portal must fail");
        assert_eq!(err, EncodeError::NullByte);
    }

    #[test]
    fn test_encode_extended_query_with_nul_rejected() {
        let err = PgEncoder::encode_extended_query_with_result_format(
            "SELECT 1\0UNION SELECT 2",
            &[],
            PgEncoder::FORMAT_TEXT,
        )
        .expect_err("extended query with NUL SQL must fail");
        assert_eq!(err, EncodeError::NullByte);
    }

    #[test]
    fn test_encode_copy_fail_with_nul_returns_error() {
        let err = PgEncoder::try_encode_copy_fail("bad\0data").expect_err("must reject");
        assert_eq!(err, EncodeError::NullByte);
    }
}