use std::collections::BTreeMap;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use super::{DecodeContext, Message, ProtocolVersion, codec};
use crate::error::{PgWireError, PgWireResult};
/// Smallest startup-phase packet: a 4-byte length field followed by a 4-byte
/// protocol version / request code. `Startup::decode_body` rejects any packet
/// whose length does not exceed this, and `SslRequest::BODY_SIZE` equals it.
pub(crate) const MINIMUM_STARTUP_MESSAGE_LEN: usize = 8;
/// Upper bound on a startup packet's length, used as
/// `Startup::max_message_length`; an alias of `super::SMALL_PACKET_SIZE_LIMIT`.
pub(crate) const MAXIMUM_STARTUP_MESSAGE_LEN: usize = super::SMALL_PACKET_SIZE_LIMIT;
/// Startup packet: a protocol version (major, minor as two `u16`s) followed by
/// NUL-terminated name/value parameter pairs and a final empty string.
///
/// It carries no message type byte; see `impl Message for Startup`.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct Startup {
/// Major protocol version; `Startup::new()` sets it to 3.
#[new(value = "3")]
pub protocol_number_major: u16,
/// Minor protocol version; `Startup::new()` sets it to 0.
#[new(value = "0")]
pub protocol_number_minor: u16,
/// Startup parameters, kept sorted by name; empty after `Startup::new()`.
#[new(default)]
pub parameters: BTreeMap<String, String>,
}
impl Default for Startup {
fn default() -> Startup {
Startup::new()
}
}
impl Startup {
    /// Protocol 3.0 packed into one `i32`: major in the high 16 bits, minor in
    /// the low 16 bits (= 196608).
    pub const PROTOCOL_VERSION_3_0: i32 = 3 << 16;
    /// Protocol 3.2 packed the same way (= 196610).
    pub const PROTOCOL_VERSION_3_2: i32 = (3 << 16) | 2;
    /// Lowest major protocol version accepted when decoding a startup packet.
    pub const PG_PROTOCOL_EARLIEST: u16 = 3;
    /// Highest major protocol version accepted when decoding a startup packet.
    pub const PG_PROTOCOL_LATEST: u16 = 3;
}
impl Message for Startup {
    /// Encoded size: 4-byte length + 2-byte major + 2-byte minor, then every
    /// `name\0value\0` pair, then the terminating `\0`.
    fn message_length(&self) -> usize {
        let mut total = 4 + 2 + 2 + 1;
        for (name, value) in &self.parameters {
            total += name.len() + 1 + value.len() + 1;
        }
        total
    }

    #[inline]
    fn max_message_length() -> usize {
        MAXIMUM_STARTUP_MESSAGE_LEN
    }

    /// Writes the version pair, each parameter as two C strings, and an empty
    /// C string that ends the parameter list.
    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        buf.put_u16(self.protocol_number_major);
        buf.put_u16(self.protocol_number_minor);
        for (name, value) in &self.parameters {
            codec::put_cstring(buf, name);
            codec::put_cstring(buf, value);
        }
        codec::put_cstring(buf, "");
        Ok(())
    }

    /// Startup packets have no type byte, so the packet header starts at offset 0.
    fn decode(buf: &mut BytesMut, ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
        codec::decode_packet(buf, 0, Self::max_message_length(), |body, full_len| {
            Self::decode_body(body, full_len, ctx)
        })
    }

    /// Parses the version pair and parameters.
    ///
    /// # Errors
    /// [`PgWireError::InvalidStartupMessage`] if `msg_len` is too short to hold
    /// anything beyond the version, or if the major version is outside
    /// `PG_PROTOCOL_EARLIEST..=PG_PROTOCOL_LATEST`.
    fn decode_body(buf: &mut BytesMut, msg_len: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
        // length field + version only: there is not even room for the list terminator
        if msg_len <= MINIMUM_STARTUP_MESSAGE_LEN {
            return Err(PgWireError::InvalidStartupMessage);
        }

        let protocol_number_major = buf.get_u16();
        let supported_majors = Self::PG_PROTOCOL_EARLIEST..=Self::PG_PROTOCOL_LATEST;
        if !supported_majors.contains(&protocol_number_major) {
            return Err(PgWireError::InvalidStartupMessage);
        }
        let protocol_number_minor = buf.get_u16();

        // Read pairs until codec::get_cstring yields None (expected at the empty
        // terminator); a name with no value maps to an empty string.
        let mut parameters = BTreeMap::new();
        while let Some(name) = codec::get_cstring(buf) {
            let value = codec::get_cstring(buf).unwrap_or_default();
            parameters.insert(name, value);
        }

        Ok(Self {
            protocol_number_major,
            protocol_number_minor,
            parameters,
        })
    }
}
/// Authentication message (type byte `R`).
///
/// Each variant is identified on the wire by an `i32` code written before any
/// payload: `Ok` = 0, `KerberosV5` = 2, `CleartextPassword` = 3,
/// `MD5Password` = 5 (followed by the salt bytes), `SASL` = 10 (NUL-terminated
/// mechanism names followed by one extra NUL), `SASLContinue` = 11 and
/// `SASLFinal` = 12 (followed by raw payload bytes).
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug)]
pub enum Authentication {
Ok, CleartextPassword, KerberosV5, MD5Password(Vec<u8>),
SASL(Vec<String>), SASLContinue(Bytes), SASLFinal(Bytes),
}
/// Message type byte of [`Authentication`] messages.
pub const MESSAGE_TYPE_BYTE_AUTHENTICATION: u8 = b'R';
impl Message for Authentication {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_AUTHENTICATION)
    }

    #[inline]
    fn max_message_length() -> usize {
        super::SMALL_BACKEND_PACKET_SIZE_LIMIT
    }

    /// 4-byte length + 4-byte auth code, plus the variant-specific payload.
    #[inline]
    fn message_length(&self) -> usize {
        let payload = match self {
            Self::Ok | Self::CleartextPassword | Self::KerberosV5 => 0,
            // fixed-size salt
            Self::MD5Password(_) => 4,
            // each name plus its NUL, then the list-terminating NUL
            Self::SASL(methods) => methods.iter().map(|m| m.len() + 1).sum::<usize>() + 1,
            Self::SASLContinue(data) | Self::SASLFinal(data) => data.len(),
        };
        8 + payload
    }

    /// Writes the auth code, then the payload belonging to the variant.
    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        let code: i32 = match self {
            Self::Ok => 0,
            Self::KerberosV5 => 2,
            Self::CleartextPassword => 3,
            Self::MD5Password(_) => 5,
            Self::SASL(_) => 10,
            Self::SASLContinue(_) => 11,
            Self::SASLFinal(_) => 12,
        };
        buf.put_i32(code);

        match self {
            Self::Ok | Self::CleartextPassword | Self::KerberosV5 => {}
            Self::MD5Password(salt) => buf.put_slice(salt),
            Self::SASL(methods) => {
                for method in methods {
                    codec::put_cstring(buf, method);
                }
                buf.put_u8(b'\0');
            }
            Self::SASLContinue(data) | Self::SASLFinal(data) => buf.put_slice(data),
        }
        Ok(())
    }

    /// Reads the auth code and its payload.
    ///
    /// # Errors
    /// [`PgWireError::InvalidAuthenticationMessageCode`] for an unknown code.
    fn decode_body(buf: &mut BytesMut, msg_len: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
        let code = buf.get_i32();
        let message = match code {
            0 => Self::Ok,
            2 => Self::KerberosV5,
            3 => Self::CleartextPassword,
            5 => {
                let mut salt = [0u8; 4];
                buf.copy_to_slice(&mut salt);
                Self::MD5Password(salt.to_vec())
            }
            10 => Self::SASL(std::iter::from_fn(|| codec::get_cstring(buf)).collect()),
            // payload is whatever follows the 4-byte length and 4-byte code
            11 => Self::SASLContinue(buf.split_to(msg_len - 8).freeze()),
            12 => Self::SASLFinal(buf.split_to(msg_len - 8).freeze()),
            unknown => return Err(PgWireError::InvalidAuthenticationMessageCode(unknown)),
        };
        Ok(message)
    }
}
/// Message type byte shared by `Password`, `SASLInitialResponse` and `SASLResponse`.
pub const MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY: u8 = b'p';
/// Any message sent with type byte `p`.
///
/// All members share the same type byte, so decoding alone cannot tell them
/// apart: `decode_body` always yields `Raw`, and the caller converts it with
/// `into_password`, `into_sasl_initial_response` or `into_sasl_response`.
#[non_exhaustive]
#[derive(Debug)]
pub enum PasswordMessageFamily {
Raw(BytesMut),
Password(Password),
SASLInitialResponse(SASLInitialResponse),
SASLResponse(SASLResponse),
}
impl Message for PasswordMessageFamily {
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
    }

    /// `Raw` counts its body plus the 4-byte length; decoded variants delegate.
    fn message_length(&self) -> usize {
        match self {
            Self::Raw(body) => 4 + body.len(),
            Self::Password(message) => message.message_length(),
            Self::SASLInitialResponse(message) => message.message_length(),
            Self::SASLResponse(message) => message.message_length(),
        }
    }

    /// `Raw` bodies are written verbatim; decoded variants encode themselves.
    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        match self {
            Self::Password(message) => message.encode_body(buf),
            Self::SASLInitialResponse(message) => message.encode_body(buf),
            Self::SASLResponse(message) => message.encode_body(buf),
            Self::Raw(body) => {
                buf.extend_from_slice(&body[..]);
                Ok(())
            }
        }
    }

    /// Keeps the body (everything after the 4-byte length) undecoded as `Raw`.
    fn decode_body(
        buf: &mut BytesMut,
        full_len: usize,
        _ctx: &DecodeContext,
    ) -> PgWireResult<Self> {
        let body_len = full_len - 4;
        Ok(Self::Raw(buf.split_to(body_len)))
    }
}
impl PasswordMessageFamily {
pub fn into_password(self) -> PgWireResult<Password> {
match self {
PasswordMessageFamily::Raw(mut body) => {
let len = body.len() + 4;
Password::decode_body(&mut body, len, &DecodeContext::default())
}
PasswordMessageFamily::Password(pass) => Ok(pass),
_ => Err(PgWireError::FailedToCoercePasswordMessage),
}
}
pub fn into_sasl_initial_response(self) -> PgWireResult<SASLInitialResponse> {
match self {
PasswordMessageFamily::Raw(mut body) => {
let len = body.len() + 4;
SASLInitialResponse::decode_body(&mut body, len, &DecodeContext::default())
}
PasswordMessageFamily::SASLInitialResponse(msg) => Ok(msg),
_ => Err(PgWireError::FailedToCoercePasswordMessage),
}
}
pub fn into_sasl_response(self) -> PgWireResult<SASLResponse> {
match self {
PasswordMessageFamily::Raw(mut body) => {
let len = body.len() + 4;
SASLResponse::decode_body(&mut body, len, &DecodeContext::default())
}
PasswordMessageFamily::SASLResponse(msg) => Ok(msg),
_ => Err(PgWireError::FailedToCoercePasswordMessage),
}
}
}
/// Password message: a single NUL-terminated string sent with type byte `p`.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct Password {
/// The password text as written on the wire (without the trailing NUL).
pub password: String,
}
impl Message for Password {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
    }

    /// 4-byte length + password bytes + trailing NUL.
    fn message_length(&self) -> usize {
        4 + self.password.len() + 1
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        codec::put_cstring(buf, self.password.as_str());
        Ok(())
    }

    /// Reads one C string; a missing string decodes as an empty password.
    fn decode_body(buf: &mut BytesMut, _: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
        let password = codec::get_cstring(buf).unwrap_or_default();
        Ok(Self::new(password))
    }
}
/// ParameterStatus message: a name and a value, each written as a
/// NUL-terminated string, under type byte `S`.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct ParameterStatus {
/// Parameter name.
pub name: String,
/// Parameter value.
pub value: String,
}
/// Message type byte of [`ParameterStatus`] messages.
pub const MESSAGE_TYPE_BYTE_PARAMETER_STATUS: u8 = b'S';
impl Message for ParameterStatus {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_PARAMETER_STATUS)
    }

    #[inline]
    fn max_message_length() -> usize {
        super::SMALL_BACKEND_PACKET_SIZE_LIMIT
    }

    /// 4-byte length + name + NUL + value + NUL.
    fn message_length(&self) -> usize {
        let name_len = self.name.len() + 1;
        let value_len = self.value.len() + 1;
        4 + name_len + value_len
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        for field in [&self.name, &self.value] {
            codec::put_cstring(buf, field);
        }
        Ok(())
    }

    /// Reads name then value; a missing string decodes as empty.
    fn decode_body(buf: &mut BytesMut, _: usize, _ctx: &DecodeContext) -> PgWireResult<Self> {
        let name = codec::get_cstring(buf).unwrap_or_default();
        let value = codec::get_cstring(buf).unwrap_or_default();
        Ok(Self::new(name, value))
    }
}
/// Secret key carried in `BackendKeyData`.
///
/// `SecretKey::decode` produces `I32` when the decode context reports protocol
/// 3.0 and `Bytes` under protocol 3.2; `Bytes` keys must be 4..=256 bytes long.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum SecretKey {
I32(i32),
Bytes(Bytes),
}
impl Default for SecretKey {
fn default() -> Self {
SecretKey::I32(0)
}
}
impl SecretKey {
    /// Returns the key as an `i32` when it is exactly four bytes wide.
    ///
    /// `Bytes` keys of length 4 are read big-endian; any other length yields `None`.
    pub fn as_i32(&self) -> Option<i32> {
        match self {
            Self::I32(v) => Some(*v),
            Self::Bytes(v) => {
                if v.len() == 4 {
                    Some((&v[..]).get_i32())
                } else {
                    None
                }
            }
        }
    }

    /// Checks that a `Bytes` key has an acceptable length; `I32` keys always pass.
    fn validate(&self) -> PgWireResult<()> {
        match self {
            SecretKey::I32(_) => Ok(()),
            SecretKey::Bytes(key_bytes) => Self::validate_bytes_len(key_bytes.len()),
        }
    }

    /// Accepts key lengths in `4..=256`; anything else is `InvalidSecretKey`.
    fn validate_bytes_len(data_len: usize) -> PgWireResult<()> {
        if !(4..=256).contains(&data_len) {
            return Err(PgWireError::InvalidSecretKey);
        }
        Ok(())
    }

    /// Number of bytes the key occupies on the wire.
    pub(crate) fn len(&self) -> usize {
        match self {
            SecretKey::I32(_) => 4,
            SecretKey::Bytes(key_bytes) => key_bytes.len(),
        }
    }

    /// Appends the key to `buf`.
    ///
    /// # Errors
    /// `InvalidSecretKey` if a `Bytes` key is outside `4..=256` bytes.
    pub fn encode(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        match self {
            SecretKey::I32(key) => buf.put_i32(*key),
            SecretKey::Bytes(key) => {
                self.validate()?;
                buf.put_slice(key)
            }
        }
        Ok(())
    }

    /// Reads a key occupying `data_len` bytes of `buf`.
    ///
    /// Under protocol 3.2 the key is taken as opaque bytes. Under protocol 3.0
    /// it must be exactly one big-endian `i32`.
    ///
    /// # Errors
    /// `InvalidSecretKey` if `data_len` is outside `4..=256`, or if it is not 4
    /// under protocol 3.0.
    pub fn decode(buf: &mut BytesMut, data_len: usize, ctx: &DecodeContext) -> PgWireResult<Self> {
        Self::validate_bytes_len(data_len)?;
        match ctx.protocol_version {
            ProtocolVersion::PROTOCOL3_2 => Ok(SecretKey::Bytes(buf.split_to(data_len).freeze())),
            ProtocolVersion::PROTOCOL3_0 => {
                // get_i32 consumes only 4 bytes; a wider key would leave the rest of
                // it unread in `buf`, so reject it instead of mis-parsing what follows.
                if data_len != 4 {
                    return Err(PgWireError::InvalidSecretKey);
                }
                Ok(SecretKey::I32(buf.get_i32()))
            }
        }
    }
}
/// BackendKeyData message (type byte `K`): a process id followed by a secret key.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct BackendKeyData {
/// Process id, written as an `i32`.
pub pid: i32,
/// Secret key; its encoded width depends on the variant (see `SecretKey::len`).
pub secret_key: SecretKey,
}
/// Message type byte of [`BackendKeyData`] messages.
pub const MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA: u8 = b'K';
impl Message for BackendKeyData {
#[inline]
fn message_type() -> Option<u8> {
Some(MESSAGE_TYPE_BYTE_BACKEND_KEY_DATA)
}
#[inline]
fn max_message_length() -> usize {
super::SMALL_BACKEND_PACKET_SIZE_LIMIT
}
#[inline]
fn message_length(&self) -> usize {
8 + self.secret_key.len()
}
fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
buf.put_i32(self.pid);
self.secret_key.encode(buf)?;
Ok(())
}
fn decode_body(buf: &mut BytesMut, msg_len: usize, ctx: &DecodeContext) -> PgWireResult<Self> {
let pid = buf.get_i32();
let secret_key = SecretKey::decode(buf, msg_len - 8, ctx)?;
Ok(BackendKeyData { pid, secret_key })
}
}
/// SSLRequest packet: a 4-byte length followed by the magic code
/// `SslRequest::BODY_MAGIC_NUMBER`, with no message type byte.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct SslRequest;
impl SslRequest {
    /// Request code sent where a protocol version would be ((1234 << 16) | 5679).
    pub const BODY_MAGIC_NUMBER: i32 = 80877103;
    /// Whole packet size: 4-byte length plus 4-byte magic code.
    pub const BODY_SIZE: usize = MINIMUM_STARTUP_MESSAGE_LEN;

    /// Returns `true` when `buf` holds at least [`Self::BODY_SIZE`] bytes and
    /// bytes 4..8, read big-endian, equal [`Self::BODY_MAGIC_NUMBER`].
    /// The length field in bytes 0..4 is not inspected.
    pub fn is_ssl_request_packet(buf: &[u8]) -> bool {
        if buf.len() < Self::BODY_SIZE {
            return false;
        }
        let request_code = i32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]]);
        request_code == Self::BODY_MAGIC_NUMBER
    }
}
impl Message for SslRequest {
    #[inline]
    fn message_type() -> Option<u8> {
        None
    }

    #[inline]
    fn message_length(&self) -> usize {
        Self::BODY_SIZE
    }

    /// Writes only the magic code; the length prefix is not written here.
    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        buf.put_i32(Self::BODY_MAGIC_NUMBER);
        Ok(())
    }

    /// Decodes the body of an SSLRequest.
    ///
    /// Following the convention of the other `decode_body` implementations in
    /// this module, `buf` is positioned just past the 4-byte length prefix and
    /// `full_len` includes that prefix.
    ///
    /// # Errors
    /// `InvalidSSLRequestMessage` if the length is not [`Self::BODY_SIZE`], the
    /// buffer is short, or the magic code does not match.
    fn decode_body(
        buf: &mut BytesMut,
        full_len: usize,
        _ctx: &DecodeContext,
    ) -> PgWireResult<Self> {
        // Return an error instead of panicking: this is a public trait method
        // and its input comes from the peer.
        if full_len != Self::BODY_SIZE || buf.remaining() < 4 {
            return Err(PgWireError::InvalidSSLRequestMessage);
        }
        if buf.get_i32() == Self::BODY_MAGIC_NUMBER {
            Ok(SslRequest)
        } else {
            Err(PgWireError::InvalidSSLRequestMessage)
        }
    }

    /// Consumes an 8-byte SSLRequest if one is fully buffered.
    ///
    /// Returns `Ok(None)` when fewer than [`Self::BODY_SIZE`] bytes are available.
    ///
    /// # Errors
    /// `InvalidSSLRequestMessage` if the buffered bytes carry a different code.
    fn decode(buf: &mut BytesMut, _ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
        if buf.remaining() >= Self::BODY_SIZE {
            if Self::is_ssl_request_packet(buf) {
                buf.advance(8);
                Ok(Some(SslRequest))
            } else {
                Err(PgWireError::InvalidSSLRequestMessage)
            }
        } else {
            Ok(None)
        }
    }
}
/// GSSENCRequest packet: a 4-byte length followed by the magic code
/// `GssEncRequest::BODY_MAGIC_NUMBER`, with no message type byte.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct GssEncRequest;
impl GssEncRequest {
pub const BODY_MAGIC_NUMBER: i32 = 80877104;
pub const BODY_SIZE: usize = 8;
pub fn is_gss_enc_request_packet(buf: &[u8]) -> bool {
if buf.remaining() >= Self::BODY_SIZE {
let magic_code = (&buf[4..8]).get_i32();
magic_code == Self::BODY_MAGIC_NUMBER
} else {
false
}
}
}
impl Message for GssEncRequest {
    #[inline]
    fn message_type() -> Option<u8> {
        None
    }

    #[inline]
    fn message_length(&self) -> usize {
        Self::BODY_SIZE
    }

    /// Writes only the magic code; the length prefix is not written here.
    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        buf.put_i32(Self::BODY_MAGIC_NUMBER);
        Ok(())
    }

    /// Decodes the body of a GSSENCRequest.
    ///
    /// Following the convention of the other `decode_body` implementations in
    /// this module, `buf` is positioned just past the 4-byte length prefix and
    /// `full_len` includes that prefix.
    ///
    /// # Errors
    /// `InvalidGssEncRequestMessage` if the length is not [`Self::BODY_SIZE`],
    /// the buffer is short, or the magic code does not match.
    fn decode_body(
        buf: &mut BytesMut,
        full_len: usize,
        _ctx: &DecodeContext,
    ) -> PgWireResult<Self> {
        // Return an error instead of panicking: this is a public trait method
        // and its input comes from the peer.
        if full_len != Self::BODY_SIZE || buf.remaining() < 4 {
            return Err(PgWireError::InvalidGssEncRequestMessage);
        }
        if buf.get_i32() == Self::BODY_MAGIC_NUMBER {
            Ok(GssEncRequest)
        } else {
            Err(PgWireError::InvalidGssEncRequestMessage)
        }
    }

    /// Consumes an 8-byte GSSENCRequest if one is fully buffered.
    ///
    /// Returns `Ok(None)` when fewer than [`Self::BODY_SIZE`] bytes are available.
    ///
    /// # Errors
    /// `InvalidGssEncRequestMessage` if the buffered bytes carry a different code.
    fn decode(buf: &mut BytesMut, _ctx: &DecodeContext) -> PgWireResult<Option<Self>> {
        if buf.remaining() >= Self::BODY_SIZE {
            if Self::is_gss_enc_request_packet(buf) {
                buf.advance(8);
                Ok(Some(GssEncRequest))
            } else {
                Err(PgWireError::InvalidGssEncRequestMessage)
            }
        } else {
            Ok(None)
        }
    }
}
/// SASLInitialResponse message (type byte `p`): the chosen mechanism name and
/// optional initial client data.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct SASLInitialResponse {
/// Mechanism name, written as a NUL-terminated string.
pub auth_method: String,
/// Initial data; `None` is encoded as a data length of -1.
pub data: Option<Bytes>,
}
impl Message for SASLInitialResponse {
#[inline]
fn message_type() -> Option<u8> {
Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
}
#[inline]
fn message_length(&self) -> usize {
4 + self.auth_method.len() + 1 + 4 + self.data.as_ref().map(|b| b.len()).unwrap_or(0)
}
fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
codec::put_cstring(buf, &self.auth_method);
if let Some(ref data) = self.data {
buf.put_i32(data.len() as i32);
buf.put_slice(data.as_ref());
} else {
buf.put_i32(-1);
}
Ok(())
}
fn decode_body(
buf: &mut BytesMut,
_full_len: usize,
_ctx: &DecodeContext,
) -> PgWireResult<Self> {
let auth_method = codec::get_cstring(buf).unwrap_or_else(|| "".to_owned());
let data_len = buf.get_i32();
let data = if data_len == -1 {
None
} else {
Some(buf.split_to(data_len as usize).freeze())
};
Ok(SASLInitialResponse { auth_method, data })
}
}
/// SASLResponse message (type byte `p`): an opaque payload filling the body.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct SASLResponse {
/// Raw payload bytes (everything after the 4-byte length).
pub data: Bytes,
}
impl Message for SASLResponse {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_PASSWORD_MESSAGE_FAMILY)
    }

    /// 4-byte length + payload.
    #[inline]
    fn message_length(&self) -> usize {
        self.data.len() + 4
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        buf.extend_from_slice(&self.data[..]);
        Ok(())
    }

    /// Takes everything after the length prefix as the payload.
    fn decode_body(
        buf: &mut BytesMut,
        full_len: usize,
        _ctx: &DecodeContext,
    ) -> PgWireResult<Self> {
        // full_len counts the 4-byte length prefix, which is not payload
        let payload_len = full_len - 4;
        let data = buf.split_to(payload_len).freeze();
        Ok(Self { data })
    }
}
/// NegotiateProtocolVersion message (type byte `v`): the newest supported minor
/// protocol version plus a list of unsupported option names.
#[non_exhaustive]
#[derive(PartialEq, Eq, Debug, new)]
pub struct NegotiateProtocolVersion {
/// Newest minor protocol version, written as an `i32`.
pub newest_minor_protocol: i32,
/// Option names, written as an `i32` count followed by NUL-terminated strings.
pub unsupported_options: Vec<String>,
}
/// Message type byte of [`NegotiateProtocolVersion`] messages.
pub const MESSAGE_TYPE_BYTE_NEGOTIATE_PROTOCOL_VERSION: u8 = b'v';
impl Message for NegotiateProtocolVersion {
    #[inline]
    fn message_type() -> Option<u8> {
        Some(MESSAGE_TYPE_BYTE_NEGOTIATE_PROTOCOL_VERSION)
    }

    #[inline]
    fn max_message_length() -> usize {
        super::SMALL_BACKEND_PACKET_SIZE_LIMIT
    }

    /// 4-byte length + 4-byte version + 4-byte option count + each option and its NUL.
    #[inline]
    fn message_length(&self) -> usize {
        12 + self
            .unsupported_options
            .iter()
            .map(|s| s.len() + 1)
            .sum::<usize>()
    }

    fn encode_body(&self, buf: &mut BytesMut) -> PgWireResult<()> {
        buf.put_i32(self.newest_minor_protocol);
        buf.put_i32(self.unsupported_options.len() as i32);
        for s in &self.unsupported_options {
            codec::put_cstring(buf, s);
        }
        Ok(())
    }

    /// Reads the version, the option count, and that many option names.
    fn decode_body(
        buf: &mut BytesMut,
        _full_len: usize,
        _ctx: &DecodeContext,
    ) -> PgWireResult<Self> {
        let version = buf.get_i32();
        let option_count = buf.get_i32();
        // option_count is peer-supplied. Casting a negative count to usize, or
        // trusting a huge one, would make with_capacity panic or allocate
        // wildly. Every option needs at least its NUL byte, so the buffered
        // length bounds any sane preallocation. A negative count yields no
        // loop iterations.
        let capacity = (option_count.max(0) as usize).min(buf.remaining());
        let mut options = Vec::with_capacity(capacity);
        for _ in 0..option_count {
            options.push(codec::get_cstring(buf).unwrap_or_else(|| "".to_owned()))
        }
        Ok(Self {
            newest_minor_protocol: version,
            unsupported_options: options,
        })
    }
}