use bytes::{Buf, Bytes};
use super::ProtocolError;
use crate::{common::ByteStr, ext::BytesExt};
/// A backend message type that can be built from its one-byte type tag and its body.
///
/// Implementors must be `Debug` and `Sized` (`decode` returns `Self` by value).
pub trait BackendProtocol: std::fmt::Debug + Sized {
    /// Decodes `body` as a message whose type tag is `msgtype`.
    ///
    /// NOTE(review): `body` is presumably the payload without the tag byte and length
    /// prefix — confirm against the framing code.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] when the message cannot be decoded; the
    /// implementations in this module fail, for example, when `msgtype` is not the
    /// implementor's own tag.
    fn decode(msgtype: u8, body: Bytes) -> Result<Self, ProtocolError>;
}
/// Any message the backend may send, wrapping the decoded per-type struct.
///
/// The `msgtype`/`message_name` helpers, the [`BackendProtocol`] impl (which
/// dispatches on the one-byte type tag), and `Debug` are all generated by
/// `match_backend!` below, so every variant here must also be listed there.
pub enum BackendMessage {
    Authentication(Authentication),
    BackendKeyData(BackendKeyData),
    BindComplete(BindComplete),
    CloseComplete(CloseComplete),
    CommandComplete(CommandComplete),
    DataRow(DataRow),
    ErrorResponse(ErrorResponse),
    EmptyQueryResponse(EmptyQueryResponse),
    NegotiateProtocolVersion(NegotiateProtocolVersion),
    NoData(NoData),
    NoticeResponse(NoticeResponse),
    ParameterDescription(ParameterDescription),
    ParameterStatus(ParameterStatus),
    ParseComplete(ParseComplete),
    PortalSuspended(PortalSuspended),
    ReadyForQuery(ReadyForQuery),
    RowDescription(RowDescription),
}
/// Generates the tag-based dispatch code for [`BackendMessage`].
///
/// Each `$name` must be both a `BackendMessage` variant wrapping the type of the same
/// name, and a type with an associated `MSGTYPE: u8` that implements `BackendProtocol`.
macro_rules! match_backend {
    ($($name:ident,)*) => {
        impl BackendMessage {
            /// Returns the one-byte type tag of the wrapped message.
            pub const fn msgtype(&self) -> u8 {
                // Exhaustive: a variant missing from the macro input fails to compile.
                match self {
                    $(Self::$name(_) => $name::MSGTYPE,)*
                }
            }
            /// Returns the type name for tag `msgtype`, or `"Unknown"` if no listed
            /// message type uses that tag.
            pub const fn message_name(msgtype: u8) -> &'static str {
                match msgtype {
                    $($name::MSGTYPE => stringify!($name),)*
                    _ => "Unknown",
                }
            }
        }
        impl BackendProtocol for BackendMessage {
            /// Picks the message type by tag and delegates to its own `decode`.
            ///
            /// # Errors
            ///
            /// `ProtocolError::unknown(msgtype)` for a tag no listed type claims;
            /// otherwise whatever the selected type's `decode` returns.
            fn decode(msgtype: u8, body: Bytes) -> Result<Self, ProtocolError> {
                let message = match msgtype {
                    $($name::MSGTYPE => Self::$name(<$name as BackendProtocol>::decode(msgtype, body)?),)*
                    _ => return Err(ProtocolError::unknown(msgtype)),
                };
                Ok(message)
            }
        }
        // Forwards to the wrapped message's `Debug`, so the enum wrapper itself is
        // not shown in the output.
        impl std::fmt::Debug for BackendMessage {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                match self {
                    $(Self::$name(e) => std::fmt::Debug::fmt(e, f),)*
                }
            }
        }
    };
}
// Every `BackendMessage` variant must appear here: the generated `msgtype` and `Debug`
// matches are exhaustive, so a missing entry is a compile error. The `MSGTYPE` tags must
// also be unique; a duplicate would make later arms of `message_name`/`decode` unreachable.
match_backend! {
    Authentication,
    BackendKeyData,
    BindComplete,
    CloseComplete,
    CommandComplete,
    DataRow,
    ErrorResponse,
    EmptyQueryResponse,
    NegotiateProtocolVersion,
    NoData,
    NoticeResponse,
    ParameterDescription,
    ParameterStatus,
    ParseComplete,
    PortalSuspended,
    ReadyForQuery,
    RowDescription,
}
/// Returns early with `Err(ProtocolError::unexpected(Self::MSGTYPE, $typ))` when the
/// received tag `$typ` differs from `Self::MSGTYPE`.
///
/// Must be expanded inside an `impl` whose `Self` has an associated `MSGTYPE: u8`,
/// within a function returning `Result<_, ProtocolError>`.
macro_rules! assert_msgtype {
    ($typ:ident) => {
        if $typ != Self::MSGTYPE {
            return Err(ProtocolError::unexpected(Self::MSGTYPE, $typ));
        }
    };
}
/// Attaches the wire tag `$tag` to `$target` as the associated constant `MSGTYPE`,
/// documenting the constant with the tag literal itself (e.g. "`b'R'`").
macro_rules! msgtype {
    ($target:ident, $tag:literal) => {
        impl $target {
            #[doc = concat!("`", stringify!($tag), "`")]
            pub const MSGTYPE: u8 = $tag;
        }
    };
}
/// `Authentication*` messages (tag `b'R'`).
///
/// The body starts with a big-endian `u32` sub-type code; the code each variant is
/// decoded from is noted on the variant.
#[derive(Debug)]
pub enum Authentication {
    /// Code 0.
    Ok,
    /// Code 2.
    KerberosV5,
    /// Code 3.
    CleartextPassword,
    /// Code 5; `salt` is the 4 bytes following the code, in wire order.
    MD5Password {
        salt: [u8;4],
    },
    /// Code 7.
    GSS,
    /// Code 8; `data` is the rest of the body, undecoded.
    GSSContinue {
        data: Bytes,
    },
    /// Code 9.
    SSPI,
    /// Code 10; `name` is the rest of the body, undecoded.
    /// NOTE(review): per the PostgreSQL protocol this is presumably a list of
    /// NUL-terminated SASL mechanism names — confirm before parsing it as one name.
    SASL {
        name: Bytes,
    },
    /// Code 11; `data` is the rest of the body, undecoded.
    SASLContinue {
        data: Bytes,
    },
    /// Code 12; `data` is the rest of the body, undecoded.
    SASLFinal {
        data: Bytes,
    },
}
// Wire tag shared by every Authentication* message.
msgtype!(Authentication, b'R');
impl BackendProtocol for Authentication {
    /// Decodes an authentication request. The body starts with a big-endian `u32`
    /// sub-type code that selects the variant.
    ///
    /// # Errors
    ///
    /// Fails via `ProtocolError::unexpected` when `msgtype` is not `b'R'`.
    ///
    /// # Panics
    ///
    /// Panics when the sub-type code is not recognized. It also panics (via
    /// `Buf::get_u32`) when `body` is too short for the code or, for code 5, the salt.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let code = body.get_u32();
        let request = match code {
            // Requests that carry nothing after the code.
            0 => Self::Ok,
            2 => Self::KerberosV5,
            3 => Self::CleartextPassword,
            7 => Self::GSS,
            9 => Self::SSPI,
            // The next four bytes, kept in wire order.
            5 => Self::MD5Password { salt: body.get_u32().to_be_bytes() },
            // The remaining body is handed over undecoded.
            8 => Self::GSSContinue { data: body },
            10 => Self::SASL { name: body },
            11 => Self::SASLContinue { data: body },
            12 => Self::SASLFinal { data: body },
            // NOTE(review): this panics on a server-supplied value; it should
            // probably become a `ProtocolError` once a suitable constructor exists.
            auth => panic!("Unknown Authentication type: \"{auth}\""),
        };
        Ok(request)
    }
}
/// `BackendKeyData` message (tag `b'K'`).
///
/// `Debug` is implemented by hand so that `secret_key` is printed as `<REDACTED>`.
#[derive(Clone, Copy)]
pub struct BackendKeyData {
    /// First big-endian `u32` of the body.
    pub process_id: u32,
    /// Second big-endian `u32` of the body.
    pub secret_key: u32,
}
msgtype!(BackendKeyData, b'K');
impl BackendProtocol for BackendKeyData {
    /// Reads two big-endian `u32`s: the process ID, then the secret key.
    ///
    /// # Panics
    ///
    /// Panics (via `Buf::get_u32`) if `body` holds fewer than 8 bytes.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        // Wire order: process ID first, then the key.
        let process_id = body.get_u32();
        let secret_key = body.get_u32();
        Ok(Self { process_id, secret_key })
    }
}
/// `ParameterStatus` message (tag `b'S'`): a parameter name and its value.
#[derive(Debug)]
pub struct ParameterStatus {
    /// First string of the body.
    pub name: ByteStr,
    /// Second string of the body.
    pub value: ByteStr,
}
msgtype!(ParameterStatus, b'S');
impl BackendProtocol for ParameterStatus {
    /// Reads two strings with `get_nul_bytestr`: the name, then the value.
    /// (The strings are presumably NUL-terminated, as the helper name suggests.)
    ///
    /// # Errors
    ///
    /// Fails on a wrong tag, or with whatever `get_nul_bytestr` reports.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let name = body.get_nul_bytestr()?;
        let value = body.get_nul_bytestr()?;
        Ok(Self { name, value })
    }
}
/// `NoticeResponse` message (tag `b'N'`), kept as its undecoded body.
///
/// NOTE(review): `Debug` is not derived or implemented in this part of the file,
/// yet `BackendProtocol` requires it — presumably it is implemented elsewhere.
pub struct NoticeResponse {
    /// The body passed to `decode`/`new`, unmodified.
    pub body: Bytes,
}
msgtype!(NoticeResponse, b'N');
impl NoticeResponse {
    /// Wraps a message body without inspecting it.
    pub fn new(body: Bytes) -> Self {
        NoticeResponse { body }
    }
}
impl BackendProtocol for NoticeResponse {
    /// Checks the tag and keeps the whole body.
    fn decode(msgtype: u8, body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        Ok(Self::new(body))
    }
}
/// `ErrorResponse` message (tag `b'E'`), kept as its undecoded body.
///
/// NOTE(review): `Debug` is not derived or implemented in this part of the file,
/// yet `BackendProtocol` requires it — presumably it is implemented elsewhere.
pub struct ErrorResponse {
    /// The body passed to `decode`/`new`, unmodified.
    pub body: Bytes,
}
msgtype!(ErrorResponse, b'E');
impl ErrorResponse {
    /// Wraps a message body without inspecting it.
    pub fn new(body: Bytes) -> Self {
        ErrorResponse { body }
    }
}
impl BackendProtocol for ErrorResponse {
    /// Checks the tag and keeps the whole body.
    fn decode(msgtype: u8, body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        Ok(Self::new(body))
    }
}
/// `RowDescription` message (tag `b'T'`), kept as its undecoded body.
///
/// `Debug` is implemented by hand and prints a `<BINARY>` placeholder for the body.
pub struct RowDescription {
    /// The body passed to `decode`, unmodified.
    pub body: Bytes,
}
msgtype!(RowDescription, b'T');
impl BackendProtocol for RowDescription {
    /// Checks the tag and stores the body as-is.
    fn decode(msgtype: u8, body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let description = Self { body };
        Ok(description)
    }
}
/// `DataRow` message (tag `b'D'`), kept as its undecoded body.
///
/// `Debug` is implemented by hand and prints a `<BINARY>` placeholder for the body.
pub struct DataRow {
    /// The body passed to `decode`, unmodified.
    pub body: Bytes,
}
msgtype!(DataRow, b'D');
impl BackendProtocol for DataRow {
    /// Checks the tag and stores the body as-is.
    fn decode(msgtype: u8, body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let row = Self { body };
        Ok(row)
    }
}
/// `CommandComplete` message (tag `b'C'`).
#[derive(Debug)]
pub struct CommandComplete {
    /// Command tag, read from the front of the body with `get_nul_bytestr`.
    pub tag: ByteStr,
}
msgtype!(CommandComplete, b'C');
impl BackendProtocol for CommandComplete {
    /// Reads the command tag. Any bytes after it are dropped with `body`.
    ///
    /// # Errors
    ///
    /// Fails on a wrong tag, or with whatever `get_nul_bytestr` reports.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let tag = body.get_nul_bytestr()?;
        Ok(Self { tag })
    }
}
/// `NegotiateProtocolVersion` message (tag `b'v'`).
///
/// NOTE(review): per the PostgreSQL protocol docs, `minor` is presumably the newest
/// minor protocol version the server supports. `len` is presumably the number of
/// unrecognized protocol options whose names make up `opt_names` — confirm.
#[derive(Debug)]
pub struct NegotiateProtocolVersion {
    /// First big-endian `u32` of the body.
    pub minor: u32,
    /// Second big-endian `u32` of the body.
    pub len: u32,
    /// Remainder of the body, undecoded.
    pub opt_names: Bytes,
}
msgtype!(NegotiateProtocolVersion, b'v');
impl BackendProtocol for NegotiateProtocolVersion {
    /// Reads two `u32`s and keeps the rest of the body as `opt_names`.
    ///
    /// # Panics
    ///
    /// Panics (via `Buf::get_u32`) if `body` holds fewer than 8 bytes.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let minor = body.get_u32();
        let len = body.get_u32();
        Ok(Self { minor, len, opt_names: body })
    }
}
/// `ParameterDescription` message (tag `b't'`).
#[derive(Debug)]
pub struct ParameterDescription {
    /// Leading big-endian `u16` of the body (the parameter count).
    pub param_len: u16,
    /// Remainder of the body, undecoded.
    /// NOTE(review): presumably one big-endian `u32` OID per parameter — confirm.
    pub oids: Bytes,
}
msgtype!(ParameterDescription, b't');
impl BackendProtocol for ParameterDescription {
    /// Reads the `u16` count and keeps the rest of the body as `oids`.
    ///
    /// # Panics
    ///
    /// Panics (via `Buf::get_u16`) if `body` holds fewer than 2 bytes.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let param_len = body.get_u16();
        Ok(Self { param_len, oids: body })
    }
}
/// `ReadyForQuery` message (tag `b'Z'`).
///
/// `Debug` is implemented by hand and renders `tx_status` as a label.
pub struct ReadyForQuery {
    /// Transaction-status byte. `Debug` maps `I`/`T`/`E` to Idle/Transaction/FailedTx.
    pub tx_status: u8,
}
msgtype!(ReadyForQuery, b'Z');
impl BackendProtocol for ReadyForQuery {
    /// Reads the single status byte at the front of `body`.
    ///
    /// # Panics
    ///
    /// Panics (via `Buf::get_u8`) if `body` is empty.
    fn decode(msgtype: u8, mut body: Bytes) -> Result<Self, ProtocolError> {
        assert_msgtype!(msgtype);
        let tx_status = body.get_u8();
        Ok(Self { tx_status })
    }
}
/// Declares payload-less backend messages.
///
/// Each `struct Name, b'x';` entry (optionally preceded by attributes or doc comments)
/// generates:
/// - a unit struct deriving `Debug`;
/// - its `MSGTYPE` constant, via `msgtype!`;
/// - a [`BackendProtocol`] impl whose `decode` only validates the tag and ignores
///   the body.
macro_rules! unit_msg {
    ($(
        $(#[$doc:meta])* struct $name:ident, $ty:literal;
    )*) => {$(
        $(#[$doc])*
        #[derive(Debug)]
        pub struct $name;
        msgtype!($name, $ty);
        impl BackendProtocol for $name {
            fn decode(msgtype: u8, _: Bytes) -> Result<Self, ProtocolError> {
                // Same tag check (and error) as every other message type in this file.
                assert_msgtype!(msgtype);
                Ok(Self)
            }
        }
    )*};
}
// Messages whose `decode` only checks the type tag; the body is ignored.
unit_msg! {
    /// `BindComplete` message (tag `b'2'`).
    struct BindComplete, b'2';
    /// `CloseComplete` message (tag `b'3'`).
    struct CloseComplete, b'3';
    /// `EmptyQueryResponse` message (tag `b'I'`).
    struct EmptyQueryResponse, b'I';
    /// `NoData` message (tag `b'n'`).
    struct NoData, b'n';
    /// `ParseComplete` message (tag `b'1'`).
    struct ParseComplete, b'1';
    /// `PortalSuspended` message (tag `b's'`).
    struct PortalSuspended, b's';
}
impl std::fmt::Debug for BackendKeyData {
    /// Prints the process ID but replaces the secret key with `<REDACTED>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the secret out of any Debug/log output.
        const HIDDEN: &str = "<REDACTED>";
        let mut builder = f.debug_struct("BackendKeyData");
        builder.field("process_id", &self.process_id);
        builder.field("secret_key", &HIDDEN);
        builder.finish()
    }
}
impl std::fmt::Debug for ReadyForQuery {
    /// Shows `tx_status` as a readable label:
    /// - `b'I'` → `Idle(I)`
    /// - `b'T'` → `Transaction(T)`
    /// - `b'E'` → `FailedTx(E)`
    ///
    /// Any other byte is shown as `unknown(0xNN)`, so an unexpected status value is
    /// still visible in diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let status: std::borrow::Cow<'static, str> = match self.tx_status {
            b'I' => "Idle(I)".into(),
            b'T' => "Transaction(T)".into(),
            b'E' => "FailedTx(E)".into(),
            // Only this unexpected path allocates.
            other => format!("unknown({other:#04x})").into(),
        };
        f.debug_struct("ReadyForQuery")
            .field("tx_status", &status)
            .finish()
    }
}
impl std::fmt::Debug for RowDescription {
    /// Prints a `<BINARY>` placeholder instead of the raw body bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = "<BINARY>";
        let mut builder = f.debug_struct("RowDescription");
        builder.field("body", &placeholder);
        builder.finish()
    }
}
impl std::fmt::Debug for DataRow {
    /// Prints a `<BINARY>` placeholder instead of the raw row bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DataRow");
        builder.field("body", &"<BINARY>");
        builder.finish()
    }
}