use zerocopy::{FromBytes, Immutable, KnownLayout};
use crate::error::{Error, Result};
use crate::protocol::codec::{read_cstr, read_i32, read_u32};
use crate::protocol::types::TransactionStatus;
pub mod auth_type {
    //! Integer type codes that open an authentication message payload.
    //!
    //! These are the values `AuthenticationMessage::parse` matches the leading
    //! `i32` of the payload against.

    /// Authentication completed successfully.
    pub const OK: i32 = 0;
    /// Kerberos V5 requested.
    pub const KERBEROS_V5: i32 = 2;
    /// Cleartext password requested.
    pub const CLEARTEXT_PASSWORD: i32 = 3;
    /// MD5 password requested (a 4-byte salt follows the code).
    pub const MD5_PASSWORD: i32 = 5;
    /// GSSAPI requested.
    pub const GSS: i32 = 7;
    /// GSSAPI/SSPI continuation data follows the code.
    pub const GSS_CONTINUE: i32 = 8;
    /// SSPI requested.
    pub const SSPI: i32 = 9;
    /// SASL requested (a list of mechanism names follows the code).
    pub const SASL: i32 = 10;
    /// SASL continuation data follows the code.
    pub const SASL_CONTINUE: i32 = 11;
    /// SASL final data follows the code.
    pub const SASL_FINAL: i32 = 12;
}
#[derive(Debug)]
pub enum AuthenticationMessage<'a> {
Ok,
KerberosV5,
CleartextPassword,
Md5Password { salt: [u8; 4] },
Gss,
GssContinue { data: &'a [u8] },
Sspi,
Sasl { mechanisms: Vec<&'a str> },
SaslContinue { data: &'a [u8] },
SaslFinal { data: &'a [u8] },
}
impl<'a> AuthenticationMessage<'a> {
pub fn parse(payload: &'a [u8]) -> Result<Self> {
let (auth_type, rest) = read_i32(payload)?;
match auth_type {
auth_type::OK => Ok(AuthenticationMessage::Ok),
auth_type::KERBEROS_V5 => Ok(AuthenticationMessage::KerberosV5),
auth_type::CLEARTEXT_PASSWORD => Ok(AuthenticationMessage::CleartextPassword),
auth_type::MD5_PASSWORD => {
if rest.len() < 4 {
return Err(Error::LibraryBug("MD5Password: missing salt".into()));
}
let mut salt = [0u8; 4];
salt.copy_from_slice(&rest[..4]);
Ok(AuthenticationMessage::Md5Password { salt })
}
auth_type::GSS => Ok(AuthenticationMessage::Gss),
auth_type::GSS_CONTINUE => Ok(AuthenticationMessage::GssContinue { data: rest }),
auth_type::SSPI => Ok(AuthenticationMessage::Sspi),
auth_type::SASL => {
let mut mechanisms = Vec::new();
let mut data = rest;
while !data.is_empty() && data[0] != 0 {
let (mechanism, remaining) = read_cstr(data)?;
mechanisms.push(mechanism);
data = remaining;
}
Ok(AuthenticationMessage::Sasl { mechanisms })
}
auth_type::SASL_CONTINUE => Ok(AuthenticationMessage::SaslContinue { data: rest }),
auth_type::SASL_FINAL => Ok(AuthenticationMessage::SaslFinal { data: rest }),
_ => Err(Error::LibraryBug(format!(
"Unknown authentication type: {}",
auth_type
))),
}
}
}
/// Process identifier and secret key decoded from a backend-key-data payload.
#[derive(Debug, Clone)]
pub struct BackendKeyData {
    pid: u32,
    secret_key: Vec<u8>,
}
impl BackendKeyData {
    /// Decodes the payload: a leading `u32` followed by the secret key bytes.
    ///
    /// # Errors
    ///
    /// Returns `Error::LibraryBug` in two cases:
    /// - the payload is shorter than four bytes;
    /// - the bytes after the `u32` are fewer than 4 or more than 256.
    ///
    /// Errors from `read_u32` are propagated.
    pub fn parse(payload: &[u8]) -> Result<Self> {
        if payload.len() < 4 {
            return Err(Error::LibraryBug(
                "BackendKeyData: payload too short".into(),
            ));
        }
        let (pid, key_bytes) = read_u32(payload)?;
        let key_len = key_bytes.len();
        if !(4..=256).contains(&key_len) {
            return Err(Error::LibraryBug(format!(
                "BackendKeyData: invalid secret key length {key_len}"
            )));
        }
        Ok(BackendKeyData {
            pid,
            secret_key: key_bytes.to_owned(),
        })
    }
    /// The `u32` read from the start of the payload.
    pub fn process_id(&self) -> u32 {
        self.pid
    }
    /// The secret key bytes, between 4 and 256 bytes long (enforced by `parse`).
    pub fn secret_key(&self) -> &[u8] {
        self.secret_key.as_slice()
    }
}
/// A name/value pair decoded from a parameter-status payload.
///
/// Both strings borrow from the payload passed to [`ParameterStatus::parse`].
#[derive(Debug, Clone)]
pub struct ParameterStatus<'a> {
    pub name: &'a str,
    pub value: &'a str,
}
impl<'a> ParameterStatus<'a> {
    /// Reads two consecutive NUL-terminated strings (name, then value).
    ///
    /// Any bytes after the value are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from `read_cstr`.
    pub fn parse(payload: &'a [u8]) -> Result<Self> {
        let (param_name, after_name) = read_cstr(payload)?;
        let (param_value, _trailing) = read_cstr(after_name)?;
        Ok(ParameterStatus {
            name: param_name,
            value: param_value,
        })
    }
}
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
#[repr(C, packed)]
pub struct ReadyForQuery {
pub status: u8,
}
impl ReadyForQuery {
pub fn parse(payload: &[u8]) -> Result<&Self> {
Self::ref_from_bytes(payload)
.map_err(|e| Error::LibraryBug(format!("ReadyForQuery: {e:?}")))
}
pub fn transaction_status(&self) -> Option<TransactionStatus> {
TransactionStatus::from_byte(self.status)
}
}
/// A decoded notification payload.
///
/// The string fields borrow from the buffer passed to [`NotificationResponse::parse`].
#[derive(Debug, Clone)]
pub struct NotificationResponse<'a> {
    /// The leading `u32` of the payload (presumably the sender's process ID).
    pub pid: u32,
    /// First NUL-terminated string after `pid`.
    pub channel: &'a str,
    /// Second NUL-terminated string after `pid`.
    pub payload: &'a str,
}
impl<'a> NotificationResponse<'a> {
    /// Reads a `u32`, then two NUL-terminated strings (channel, then payload).
    ///
    /// Any bytes after the second string are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from `read_u32` and `read_cstr`.
    pub fn parse(payload: &'a [u8]) -> Result<Self> {
        let (sender_pid, after_pid) = read_u32(payload)?;
        let (channel_name, after_channel) = read_cstr(after_pid)?;
        let (message, _trailing) = read_cstr(after_channel)?;
        Ok(NotificationResponse {
            pid: sender_pid,
            channel: channel_name,
            payload: message,
        })
    }
}
/// A decoded protocol-version negotiation payload.
///
/// Option names borrow from the buffer passed to [`NegotiateProtocolVersion::parse`].
#[derive(Debug, Clone)]
pub struct NegotiateProtocolVersion<'a> {
    /// First `u32` of the payload.
    pub newest_minor_version: u32,
    /// The NUL-terminated option names that follow the count, in payload order.
    pub unrecognized_options: Vec<&'a str>,
}
impl<'a> NegotiateProtocolVersion<'a> {
    /// Decodes the payload in three parts:
    /// - a `u32` minor version;
    /// - a `u32` option count;
    /// - that many NUL-terminated option names.
    ///
    /// Any bytes after the last declared option are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from `read_u32` / `read_cstr`, for example when the
    /// payload ends before the declared number of options has been read.
    pub fn parse(payload: &'a [u8]) -> Result<Self> {
        let (newest_minor_version, rest) = read_u32(payload)?;
        let (num_options, mut rest) = read_u32(rest)?;
        // `num_options` is taken straight from the wire. Reserving it verbatim lets
        // a corrupt or hostile count (e.g. 0xFFFF_FFFF) request a multi-gigabyte
        // allocation and abort the process before parsing can fail cleanly.
        // Each option is a NUL-terminated string, so the remaining byte count
        // bounds how many options can genuinely follow. Capacity is only a hint,
        // so capping it never changes the parsed result.
        let capacity = rest.len().min(num_options as usize);
        let mut unrecognized_options = Vec::with_capacity(capacity);
        for _ in 0..num_options {
            let (option, remaining) = read_cstr(rest)?;
            unrecognized_options.push(option);
            rest = remaining;
        }
        Ok(Self {
            newest_minor_version,
            unrecognized_options,
        })
    }
}