use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Startup request code for protocol 3.0: major version in the high 16 bits, minor 0.
pub const PG_PROTOCOL_V3: u32 = 3 << 16;
/// Request code of an SSLRequest, encoded as the reserved "version" 1234.5679.
pub const PG_SSL_REQUEST: u32 = (1234 << 16) | 5679;
/// Request code of a GSSENCRequest, encoded as the reserved "version" 1234.5680.
pub const PG_GSSENC_REQUEST: u32 = (1234 << 16) | 5680;
/// Request code of a CancelRequest, encoded as the reserved "version" 1234.5678.
pub const PG_CANCEL_REQUEST: u32 = (1234 << 16) | 5678;
/// Errors produced while reading or writing PostgreSQL wire messages.
#[derive(Debug)]
pub enum PgWireError {
    /// An I/O failure whose kind is anything other than `UnexpectedEof`.
    Io(io::Error),
    /// A protocol violation, with a human-readable description.
    Protocol(String),
    /// Produced from any `io::ErrorKind::UnexpectedEof` error (see the `From` impl).
    Eof,
}

impl From<io::Error> for PgWireError {
    /// Classifies an I/O error: `UnexpectedEof` becomes [`PgWireError::Eof`],
    /// everything else is wrapped unchanged in [`PgWireError::Io`].
    fn from(source: io::Error) -> Self {
        match source.kind() {
            io::ErrorKind::UnexpectedEof => Self::Eof,
            _ => Self::Io(source),
        }
    }
}

impl std::fmt::Display for PgWireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The inner io::Error text is embedded here, so `source()` is
            // intentionally left at its default to avoid printing it twice.
            Self::Io(inner) => write!(f, "pg wire io: {inner}"),
            Self::Protocol(detail) => write!(f, "pg wire protocol: {detail}"),
            Self::Eof => f.write_str("pg wire eof"),
        }
    }
}

impl std::error::Error for PgWireError {}
/// A message received from a PostgreSQL client, as decoded by `read_startup`
/// or `read_frame`.
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Protocol 3.0 StartupMessage with its name/value parameters.
    Startup(StartupParams),
    /// SSLRequest received during the startup phase.
    SslRequest,
    /// GSSENCRequest received during the startup phase.
    GssEncRequest,
    /// Simple query (`Q`): the SQL text without its terminating NUL.
    Query(String),
    /// `p` message: the raw, undecoded payload.
    PasswordMessage(Vec<u8>),
    /// Terminate (`X`).
    Terminate,
    /// Flush (`H`).
    Flush,
    /// Sync (`S`).
    Sync,
    /// Any other message, kept as its tag byte plus raw payload.
    ///
    /// `read_startup` also reports a CancelRequest this way, with `tag: b'K'`
    /// and the whole startup body (request code included) as `payload`.
    Unknown { tag: u8, payload: Vec<u8> },
}
/// Name/value parameters carried by a protocol 3.0 StartupMessage.
#[derive(Debug, Clone, Default)]
pub struct StartupParams {
    /// Parameter pairs in the order they were received; duplicates are kept.
    pub params: Vec<(String, String)>,
}

impl StartupParams {
    /// Returns the value of the first parameter named `key`, or `None` if absent.
    pub fn get(&self, key: &str) -> Option<&str> {
        for (name, value) in &self.params {
            if name == key {
                return Some(value.as_str());
            }
        }
        None
    }
}
/// A message the server sends to the client; serialized by `write_frame`.
#[derive(Debug, Clone)]
pub enum BackendMessage {
    /// AuthenticationOk (`R` with authentication code 0).
    AuthenticationOk,
    /// ParameterStatus (`S`): a parameter name and its current value.
    ParameterStatus { name: String, value: String },
    /// BackendKeyData (`K`): `pid` and `key`, each sent as a big-endian `u32`.
    BackendKeyData { pid: u32, key: u32 },
    /// ReadyForQuery (`Z`) carrying the transaction status indicator byte.
    ReadyForQuery(TransactionStatus),
    /// RowDescription (`T`): one descriptor per result column.
    RowDescription(Vec<ColumnDescriptor>),
    /// DataRow (`D`): one entry per column; `None` is sent with length -1.
    DataRow(Vec<Option<Vec<u8>>>),
    /// CommandComplete (`C`) with its command tag text.
    CommandComplete(String),
    /// ErrorResponse (`E`).
    ErrorResponse {
        /// Written to both the `S` and `V` fields.
        severity: String,
        /// Written to the `C` field.
        code: String,
        /// Written to the `M` field.
        message: String,
    },
    /// NoticeResponse (`N`); always sent with severity `NOTICE`.
    NoticeResponse { message: String },
    /// EmptyQueryResponse (`I`), which has an empty payload.
    EmptyQueryResponse,
}
/// Transaction state reported to the client in ReadyForQuery.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Not in a transaction block; sent as `I`.
    Idle,
    /// Inside a transaction block; sent as `T`.
    InTransaction,
    /// Inside a failed transaction block; sent as `E`.
    Failed,
}

impl TransactionStatus {
    /// Returns the single status indicator byte used in the ReadyForQuery payload.
    pub fn as_byte(self) -> u8 {
        let indicator = match self {
            Self::Idle => 'I',
            Self::InTransaction => 'T',
            Self::Failed => 'E',
        };
        indicator as u8
    }
}
/// One column entry of a RowDescription message.
///
/// `encode_backend` writes `name` as a NUL-terminated string followed by the
/// numeric fields in declaration order, each big-endian.
#[derive(Debug, Clone)]
pub struct ColumnDescriptor {
    /// Column name.
    pub name: String,
    /// Table OID field (0 when the column is not a plain table column, per the
    /// PostgreSQL protocol).
    pub table_oid: u32,
    /// Column attribute number within that table (0 when not applicable).
    pub column_attr: i16,
    /// Data type OID.
    pub type_oid: u32,
    /// Data type size; negative values denote variable width in the protocol.
    pub type_size: i16,
    /// Type modifier (type-specific; -1 when unused).
    pub type_mod: i32,
    /// Format code sent to the client (0 = text, 1 = binary in the protocol).
    pub format: i16,
}
/// Reads the first, untagged message a client sends on a new connection.
///
/// Startup-phase messages have no type byte: a big-endian `u32` length (which
/// counts itself) is followed by a `u32` request code that selects the result:
///
/// * [`PG_SSL_REQUEST`] → [`FrontendMessage::SslRequest`]
/// * [`PG_GSSENC_REQUEST`] → [`FrontendMessage::GssEncRequest`]
/// * [`PG_PROTOCOL_V3`] → [`FrontendMessage::Startup`], parsed from the
///   `name\0value\0` pairs that follow. The list must end with a single zero
///   byte that is also the last byte of the message.
/// * [`PG_CANCEL_REQUEST`] → [`FrontendMessage::Unknown`] with tag `b'K'` and the
///   raw body (request code included) as payload.
///
/// # Errors
///
/// * [`PgWireError::Eof`] if the stream ends before the whole message arrives.
/// * [`PgWireError::Protocol`] for a length outside `8..=65536`, an unsupported
///   request code, a malformed or unterminated parameter list, or a parameter
///   that is not valid UTF-8.
/// * [`PgWireError::Io`] for any other read failure.
pub async fn read_startup<R: AsyncRead + Unpin>(
    stream: &mut R,
) -> Result<FrontendMessage, PgWireError> {
    let mut len_buf = [0u8; 4];
    stream.read_exact(&mut len_buf).await?;
    let len = u32::from_be_bytes(len_buf);
    // Validate before allocating so a hostile length cannot force a huge buffer.
    // The lower bound of 8 (length word + request code) also guarantees that the
    // body holds the 4-byte request code indexed below.
    if !(8..=65536).contains(&len) {
        return Err(PgWireError::Protocol(format!(
            "startup length {len} out of range"
        )));
    }
    let body_len = (len as usize) - 4;
    let mut body = vec![0u8; body_len];
    stream.read_exact(&mut body).await?;
    let version = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
    match version {
        PG_SSL_REQUEST => Ok(FrontendMessage::SslRequest),
        PG_GSSENC_REQUEST => Ok(FrontendMessage::GssEncRequest),
        PG_PROTOCOL_V3 => {
            let mut params: Vec<(String, String)> = Vec::new();
            let mut pos = 4usize;
            // Pairs continue until an empty name, i.e. the list terminator byte.
            while pos < body_len && body[pos] != 0 {
                let key = read_cstring(&body, &mut pos)?;
                if pos >= body_len {
                    return Err(PgWireError::Protocol(
                        "startup parameter missing value".into(),
                    ));
                }
                let value = read_cstring(&body, &mut pos)?;
                params.push((key, value));
            }
            // The terminator is mandatory and must be the final byte. Reject both
            // a missing terminator (pos == body_len) and trailing bytes after it,
            // matching PostgreSQL's own startup-packet validation.
            if pos + 1 != body_len {
                return Err(PgWireError::Protocol(
                    "invalid startup packet layout: expected terminator as last byte".into(),
                ));
            }
            Ok(FrontendMessage::Startup(StartupParams { params }))
        }
        PG_CANCEL_REQUEST => Ok(FrontendMessage::Unknown {
            tag: b'K',
            payload: body,
        }),
        _ => Err(PgWireError::Protocol(format!(
            "unsupported protocol version {version}"
        ))),
    }
}
pub async fn read_frame<R: AsyncRead + Unpin>(
stream: &mut R,
) -> Result<FrontendMessage, PgWireError> {
let mut tag_buf = [0u8; 1];
match stream.read_exact(&mut tag_buf).await {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Err(PgWireError::Eof),
Err(e) => return Err(PgWireError::Io(e)),
}
let tag = tag_buf[0];
let mut len_buf = [0u8; 4];
stream.read_exact(&mut len_buf).await?;
let len = u32::from_be_bytes(len_buf);
if !(4..=1_048_576).contains(&len) {
return Err(PgWireError::Protocol(format!(
"frame length {len} out of bounds"
)));
}
let payload_len = (len as usize) - 4;
let mut payload = vec![0u8; payload_len];
stream.read_exact(&mut payload).await?;
Ok(match tag {
b'Q' => {
let mut pos = 0;
let query = read_cstring(&payload, &mut pos)?;
FrontendMessage::Query(query)
}
b'p' => FrontendMessage::PasswordMessage(payload),
b'X' => FrontendMessage::Terminate,
b'H' => FrontendMessage::Flush,
b'S' => FrontendMessage::Sync,
other => FrontendMessage::Unknown {
tag: other,
payload,
},
})
}
pub async fn write_raw_byte<W: AsyncWrite + Unpin>(
stream: &mut W,
byte: u8,
) -> Result<(), PgWireError> {
stream.write_all(&[byte]).await?;
Ok(())
}
pub async fn write_frame<W: AsyncWrite + Unpin>(
stream: &mut W,
msg: &BackendMessage,
) -> Result<(), PgWireError> {
let (tag, payload) = encode_backend(msg);
let length = (payload.len() + 4) as u32;
stream.write_all(&[tag]).await?;
stream.write_all(&length.to_be_bytes()).await?;
stream.write_all(&payload).await?;
Ok(())
}
/// Returns a copy of `input` with every NUL byte replaced by the UTF-8
/// encoding of U+FFFD (`EF BF BD`).
///
/// Protocol strings are NUL-terminated, so an embedded NUL would end a field
/// early and let the remaining bytes be parsed as further fields. Input without
/// NULs is returned as an unchanged copy.
fn sanitize_cstring_bytes(input: &[u8]) -> Vec<u8> {
    /// UTF-8 encoding of U+FFFD REPLACEMENT CHARACTER.
    const REPLACEMENT: [u8; 3] = [0xEF, 0xBF, 0xBD];

    let nul_count = input.iter().filter(|&&b| b == 0).count();
    if nul_count == 0 {
        return input.to_vec();
    }
    // Each 1-byte NUL becomes 3 bytes, so the output size is known exactly;
    // the previous fixed `+ 8` slack forced regrowth beyond four NULs.
    let mut out = Vec::with_capacity(input.len() + 2 * nul_count);
    for &b in input {
        if b == 0 {
            out.extend_from_slice(&REPLACEMENT);
        } else {
            out.push(b);
        }
    }
    out
}
/// Appends `value` to `buf` as a protocol C string: its bytes, then one NUL.
///
/// Embedded NULs are rewritten to U+FFFD by `sanitize_cstring_bytes`, so the
/// value cannot terminate early. The common NUL-free case is copied straight
/// into `buf`, avoiding a temporary allocation per string.
#[inline]
fn push_cstring(buf: &mut Vec<u8>, value: &str) {
    let bytes = value.as_bytes();
    if bytes.contains(&0) {
        buf.extend_from_slice(&sanitize_cstring_bytes(bytes));
    } else {
        buf.extend_from_slice(bytes);
    }
    buf.push(0);
}
/// Serializes a [`BackendMessage`] into its one-byte message tag and payload.
///
/// The payload excludes both the tag and the 4-byte length word, which
/// `write_frame` prepends. Every string goes through `push_cstring`, so an
/// embedded NUL is replaced with U+FFFD instead of ending the field early.
///
/// NOTE(review): column/field counts are narrowed with `as i16` and value
/// lengths with `as i32`. These wrap silently beyond `i16::MAX` columns or an
/// `i32::MAX`-byte value.
fn encode_backend(msg: &BackendMessage) -> (u8, Vec<u8>) {
    // ErrorResponse/NoticeResponse body: (field-type byte, C string) pairs,
    // closed by one extra zero byte.
    fn field_list(fields: &[(u8, &str)]) -> Vec<u8> {
        let mut body = Vec::new();
        for &(field_type, text) in fields {
            body.push(field_type);
            push_cstring(&mut body, text);
        }
        body.push(0);
        body
    }

    match msg {
        // A single Int32 of 0: the "authentication succeeded" code.
        BackendMessage::AuthenticationOk => (b'R', 0u32.to_be_bytes().to_vec()),
        BackendMessage::ParameterStatus { name, value } => {
            let mut body = Vec::with_capacity(name.len() + value.len() + 2);
            for text in [name, value] {
                push_cstring(&mut body, text);
            }
            (b'S', body)
        }
        BackendMessage::BackendKeyData { pid, key } => {
            (b'K', [pid.to_be_bytes(), key.to_be_bytes()].concat())
        }
        BackendMessage::ReadyForQuery(status) => (b'Z', vec![status.as_byte()]),
        BackendMessage::RowDescription(columns) => {
            let mut body = Vec::new();
            body.extend((columns.len() as i16).to_be_bytes());
            for column in columns {
                push_cstring(&mut body, &column.name);
                body.extend(column.table_oid.to_be_bytes());
                body.extend(column.column_attr.to_be_bytes());
                body.extend(column.type_oid.to_be_bytes());
                body.extend(column.type_size.to_be_bytes());
                body.extend(column.type_mod.to_be_bytes());
                body.extend(column.format.to_be_bytes());
            }
            (b'T', body)
        }
        BackendMessage::DataRow(fields) => {
            let mut body = Vec::new();
            body.extend((fields.len() as i16).to_be_bytes());
            for field in fields {
                if let Some(value) = field {
                    body.extend((value.len() as i32).to_be_bytes());
                    body.extend_from_slice(value);
                } else {
                    // Length -1 encodes SQL NULL; no value bytes follow.
                    body.extend((-1i32).to_be_bytes());
                }
            }
            (b'D', body)
        }
        BackendMessage::CommandComplete(tag) => {
            let mut body = Vec::with_capacity(tag.len() + 1);
            push_cstring(&mut body, tag);
            (b'C', body)
        }
        BackendMessage::ErrorResponse {
            severity,
            code,
            message,
        } => {
            // Severity is sent twice: as `S` and as `V` (the protocol's
            // non-localized severity field).
            let body = field_list(&[
                (b'S', severity.as_str()),
                (b'V', severity.as_str()),
                (b'C', code.as_str()),
                (b'M', message.as_str()),
            ]);
            (b'E', body)
        }
        BackendMessage::NoticeResponse { message } => {
            (b'N', field_list(&[(b'S', "NOTICE"), (b'M', message.as_str())]))
        }
        BackendMessage::EmptyQueryResponse => (b'I', Vec::new()),
    }
}
/// Decodes the NUL-terminated UTF-8 string that starts at `*pos` in `buf`.
///
/// On success, returns the string without its terminator and advances `*pos`
/// just past the NUL.
///
/// # Errors
///
/// [`PgWireError::Protocol`] when:
/// * no NUL occurs at or after `*pos`. `*pos` is then left at `buf.len()`, or
///   unchanged if it already pointed past the end.
/// * the bytes before the NUL are not valid UTF-8. `*pos` is then left on the NUL.
fn read_cstring(buf: &[u8], pos: &mut usize) -> Result<String, PgWireError> {
    let start = *pos;
    let terminator = buf
        .get(start..)
        .and_then(|rest| rest.iter().position(|&byte| byte == 0))
        .map(|offset| start + offset);
    let end = match terminator {
        Some(end) => end,
        None => {
            *pos = start.max(buf.len());
            return Err(PgWireError::Protocol("cstring missing terminator".into()));
        }
    };
    *pos = end;
    let text = std::str::from_utf8(&buf[start..end])
        .map_err(|e| PgWireError::Protocol(format!("invalid utf8: {e}")))?;
    *pos = end + 1;
    Ok(text.to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn parse_startup_v3() {
        let mut payload: Vec<u8> = Vec::new();
        payload.extend_from_slice(&PG_PROTOCOL_V3.to_be_bytes());
        payload.extend_from_slice(b"user\0alice\0");
        payload.push(0);
        let len = (4 + payload.len()) as u32;
        let mut frame = Vec::new();
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(&payload);
        let mut cursor = std::io::Cursor::new(frame);
        let msg = read_startup(&mut cursor).await.unwrap();
        match msg {
            FrontendMessage::Startup(params) => {
                assert_eq!(params.get("user"), Some("alice"));
            }
            other => panic!("expected Startup, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn parse_ssl_request() {
        let mut frame: Vec<u8> = Vec::new();
        frame.extend_from_slice(&8u32.to_be_bytes());
        frame.extend_from_slice(&PG_SSL_REQUEST.to_be_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        assert!(matches!(
            read_startup(&mut cursor).await.unwrap(),
            FrontendMessage::SslRequest
        ));
    }

    #[tokio::test]
    async fn startup_rejects_unsupported_version() {
        let mut frame: Vec<u8> = Vec::new();
        frame.extend_from_slice(&8u32.to_be_bytes());
        frame.extend_from_slice(&(2u32 << 16).to_be_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        assert!(matches!(
            read_startup(&mut cursor).await,
            Err(PgWireError::Protocol(_))
        ));
    }

    #[tokio::test]
    async fn startup_rejects_length_below_minimum() {
        let mut cursor = std::io::Cursor::new(7u32.to_be_bytes().to_vec());
        assert!(matches!(
            read_startup(&mut cursor).await,
            Err(PgWireError::Protocol(_))
        ));
    }

    #[tokio::test]
    async fn parse_query_frame() {
        let query = "SELECT 1\0";
        let mut frame = Vec::new();
        frame.push(b'Q');
        let len = (4 + query.len()) as u32;
        frame.extend_from_slice(&len.to_be_bytes());
        frame.extend_from_slice(query.as_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        match read_frame(&mut cursor).await.unwrap() {
            FrontendMessage::Query(s) => assert_eq!(s, "SELECT 1"),
            other => panic!("expected Query, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn read_frame_reports_eof_on_closed_stream() {
        let mut cursor = std::io::Cursor::new(Vec::<u8>::new());
        assert!(matches!(
            read_frame(&mut cursor).await,
            Err(PgWireError::Eof)
        ));
    }

    #[tokio::test]
    async fn read_frame_rejects_length_below_four() {
        let mut frame = vec![b'Q'];
        frame.extend_from_slice(&3u32.to_be_bytes());
        let mut cursor = std::io::Cursor::new(frame);
        assert!(matches!(
            read_frame(&mut cursor).await,
            Err(PgWireError::Protocol(_))
        ));
    }

    #[tokio::test]
    async fn emit_ready_for_query() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
        )
        .await
        .unwrap();
        assert_eq!(out, vec![b'Z', 0, 0, 0, 5, b'I']);
    }

    #[tokio::test]
    async fn emit_row_description_and_data_row() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::RowDescription(vec![ColumnDescriptor {
                name: "id".to_string(),
                table_oid: 0,
                column_attr: 0,
                type_oid: 23,
                type_size: 4,
                type_mod: -1,
                format: 0,
            }]),
        )
        .await
        .unwrap();
        // Pin the complete encoding so field order/width regressions are caught,
        // not just the tag byte.
        let expected_row_description: Vec<u8> = vec![
            b'T', 0, 0, 0, 27, // tag, length (4 + 23-byte payload)
            0, 1, // column count
            b'i', b'd', 0, // column name
            0, 0, 0, 0, // table OID
            0, 0, // column attribute number
            0, 0, 0, 23, // type OID
            0, 4, // type size
            0xFF, 0xFF, 0xFF, 0xFF, // type modifier (-1)
            0, 0, // format code
        ];
        assert_eq!(out, expected_row_description);

        let mut data: Vec<u8> = Vec::new();
        write_frame(
            &mut data,
            &BackendMessage::DataRow(vec![Some(b"42".to_vec()), None]),
        )
        .await
        .unwrap();
        let expected_data_row: Vec<u8> = vec![
            b'D', 0, 0, 0, 16, // tag, length (4 + 12-byte payload)
            0, 2, // field count
            0, 0, 0, 2, b'4', b'2', // field 1: length 2, "42"
            0xFF, 0xFF, 0xFF, 0xFF, // field 2: NULL (length -1)
        ];
        assert_eq!(data, expected_data_row);
    }

    #[tokio::test]
    async fn emit_error_response_field_layout() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::ErrorResponse {
                severity: "ERROR".to_string(),
                code: "42000".to_string(),
                message: "boom".to_string(),
            },
        )
        .await
        .unwrap();
        let mut expected = vec![b'E', 0, 0, 0, 32];
        expected.extend_from_slice(b"SERROR\0VERROR\0C42000\0Mboom\0\0");
        assert_eq!(out, expected);
    }

    fn count_nul(buf: &[u8]) -> usize {
        buf.iter().filter(|&&b| b == 0).count()
    }

    #[tokio::test]
    async fn pg3_nul_error_response_message_field_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::ErrorResponse {
                severity: "ERROR".to_string(),
                code: "42000".to_string(),
                message: "smuggled\0M\x00injection".to_string(),
            },
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'E');
        let body = &out[5..];
        assert_eq!(
            count_nul(body),
            5,
            "expected 5 NULs (4 field + 1 list-end), got {} :: body={:?}",
            count_nul(body),
            body
        );
        assert!(
            body.windows(3).any(|w| w == [0xEF, 0xBF, 0xBD]),
            "expected U+FFFD substitution in body"
        );
    }

    #[tokio::test]
    async fn pg3_nul_notice_response_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::NoticeResponse {
                message: "evil\0field".to_string(),
            },
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'N');
        let body = &out[5..];
        assert_eq!(count_nul(body), 3);
        assert!(body.windows(3).any(|w| w == [0xEF, 0xBF, 0xBD]));
    }

    #[tokio::test]
    async fn pg3_nul_command_complete_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::CommandComplete("SELECT\0;DROP".to_string()),
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'C');
        let body = &out[5..];
        assert_eq!(count_nul(body), 1);
    }

    #[tokio::test]
    async fn pg3_nul_row_description_column_name_sanitized() {
        let mut out: Vec<u8> = Vec::new();
        write_frame(
            &mut out,
            &BackendMessage::RowDescription(vec![ColumnDescriptor {
                name: "evil\0col".to_string(),
                table_oid: 0,
                column_attr: 0,
                type_oid: 23,
                type_size: 4,
                type_mod: -1,
                format: 0,
            }]),
        )
        .await
        .unwrap();
        assert_eq!(out[0], b'T');
        let body = &out[5..];
        let name_region = &body[2..];
        let first_nul = name_region.iter().position(|&b| b == 0).unwrap();
        assert!(
            name_region[..first_nul]
                .windows(3)
                .any(|w| w == [0xEF, 0xBF, 0xBD]),
            "U+FFFD missing from sanitized column name"
        );
    }

    #[test]
    fn sanitize_cstring_fastpath_no_nul() {
        let s = "no nuls here";
        let out = sanitize_cstring_bytes(s.as_bytes());
        assert_eq!(out, s.as_bytes());
    }

    #[test]
    fn sanitize_cstring_substitutes_nul_with_replacement_codepoint() {
        let s = b"a\0b\0c";
        let out = sanitize_cstring_bytes(s);
        assert_eq!(out.len(), 9);
        assert!(!out.contains(&0));
        assert_eq!(&out[1..4], &[0xEF, 0xBF, 0xBD]);
        assert_eq!(&out[5..8], &[0xEF, 0xBF, 0xBD]);
    }
}