use bytes::{BufMut, BytesMut};
use std::fmt;
use super::{Oid, PgFormat};
use crate::ext::{BindParams, BufMutExt, StrExt, UsizeExt};
/// Frames one frontend message and appends it to `buf`.
///
/// The frame is: the one-byte `F::MSGTYPE`, then a `u32` length equal to
/// `4 + msg.size_hint()` (the length counts its own four bytes plus the body,
/// but not the type byte), then the body produced by `msg.encode`.
///
/// # Panics
///
/// * If `4 + msg.size_hint()` does not fit in a `u32`. The original plain
///   addition would wrap silently in release builds and emit a bogus length.
/// * If `encode` writes a different number of bytes than `size_hint()`
///   promised, since the length field would then be wrong.
pub fn write<F: FrontendProtocol>(msg: F, buf: &mut BytesMut) {
    // Type byte + length field.
    const PREFIX: usize = 1 + 4;

    let size_hint = msg.size_hint();
    // The length field includes itself; refuse to wrap on overflow.
    let len_field = size_hint
        .checked_add(4)
        .expect("frontend message too large for the u32 length field");
    // `u32 -> usize` is lossless on every 32/64-bit target.
    let expected = PREFIX + size_hint as usize;

    buf.reserve(expected);
    let offset = buf.len();
    buf.put_u8(F::MSGTYPE);
    buf.put_u32(len_field);
    msg.encode(&mut *buf);
    assert_eq!(
        buf.len() - offset,
        expected,
        "Frontend message body size not equal to size hint"
    );
}
/// A frontend (client) message that [`write`] can frame and append to a buffer.
///
/// [`write`] emits `MSGTYPE`, then a `u32` length of `4 + size_hint()`, then
/// whatever `encode` writes. It asserts that `encode` wrote exactly
/// `size_hint()` bytes, so the two methods must agree.
pub trait FrontendProtocol: fmt::Debug {
    /// Message type byte written before the length field.
    const MSGTYPE: u8;
    /// Exact number of body bytes `encode` will write (the type byte and the
    /// length field are not included).
    fn size_hint(&self) -> u32;
    /// Writes the message body into `buf`, consuming the message.
    fn encode(self, buf: impl BufMut);
}
/// Parameters of the startup message that opens a session.
///
/// Unlike the other frontend messages, this one has no leading type byte.
/// It is therefore serialized by [`Startup::write`] rather than through
/// [`FrontendProtocol`].
#[derive(Debug)]
pub struct Startup<'a> {
    /// Value of the `user` parameter; always sent.
    pub user: &'a str,
    /// Value of the `database` parameter; left out of the message when `None`.
    pub database: Option<&'a str>,
    /// Value of the `replication` parameter; left out of the message when `None`.
    pub replication: Option<&'a str>,
}

impl Startup<'_> {
    /// Protocol version 3.0: major version in the high 16 bits, minor in the low 16.
    const PROTOCOL_VERSION: u32 = 3 << 16;

    /// Appends the startup message to `buf`.
    ///
    /// The layout is:
    /// * a `u32` total length that counts itself;
    /// * the `u32` protocol version;
    /// * nul-terminated key/value string pairs;
    /// * a final nul byte.
    pub fn write(self, buf: &mut BytesMut) {
        let start = buf.len();

        // The total length is only known after the parameters are written, so
        // reserve its slot now and back-patch it at the end.
        buf.put_u32(0);
        buf.put_u32(Self::PROTOCOL_VERSION);

        let params = [
            ("user", Some(self.user)),
            ("database", self.database),
            ("replication", self.replication),
        ];
        for &(key, value) in &params {
            if let Some(value) = value {
                buf.put_nul_string(key);
                buf.put_nul_string(value);
            }
        }
        buf.put_u8(b'\0');

        let total_len = (buf.len() - start).to_u32();
        buf[start..start + 4].copy_from_slice(&total_len.to_be_bytes());
    }
}
/// Byte-size helpers for `size_hint` implementations; both forms yield a `u32`.
///
/// * `size_of!(self.field)` is the size in bytes of `self.field`'s type.
/// * `size_of!(self.items as T, in ..self.count)` is `self.count` elements of
///   type `T`. The leading `self.items` tokens are never evaluated; they only
///   document which field is being sized.
///
/// Panics if a type's size does not fit in a `u32`.
macro_rules! size_of {
    ($items:tt.$items_field:ident as $elem:ty, in ..$count:tt.$count_field:ident) => {
        (<u32 as ::std::convert::TryFrom<usize>>::try_from(::std::mem::size_of::<$elem>())
            .expect("data type size too large for postgres")
            * $count.$count_field as u32)
    };
    ($owner:tt.$field:ident) => {
        <u32 as ::std::convert::TryFrom<usize>>::try_from(::std::mem::size_of_val(&$owner.$field))
            .expect("data type size too large for postgres")
    };
}
/// Password message (`'p'`); the body is the password as a nul-terminated string.
///
/// There is deliberately no derived `Debug`; this file implements it by hand
/// so the password is redacted.
pub struct PasswordMessage<'a> {
    /// Password payload, written verbatim followed by a nul byte.
    pub password: &'a str,
}

impl FrontendProtocol for PasswordMessage<'_> {
    const MSGTYPE: u8 = b'p';

    fn size_hint(&self) -> u32 {
        let password = self.password;
        password.nul_string_len()
    }

    fn encode(self, mut out: impl BufMut) {
        let Self { password } = self;
        out.put_nul_string(password);
    }
}
/// Query message (`'Q'`); the body is the SQL text as a nul-terminated string.
#[derive(Debug)]
pub struct Query<'a> {
    /// SQL text to send.
    pub sql: &'a str,
}

impl FrontendProtocol for Query<'_> {
    const MSGTYPE: u8 = b'Q';

    fn size_hint(&self) -> u32 {
        let Query { sql } = *self;
        sql.nul_string_len()
    }

    fn encode(self, mut out: impl BufMut) {
        let Query { sql } = self;
        out.put_nul_string(sql);
    }
}
/// Parse message (`'P'`).
///
/// The body is the statement name, the SQL text, a `u16` count, then that
/// many `u32` parameter-type OIDs.
pub struct Parse<'a, I> {
    /// Name under which the statement is prepared.
    pub prepare_name: &'a str,
    /// SQL text to prepare.
    pub sql: &'a str,
    /// Number of OIDs `oids` yields. It is written as the count field and
    /// trusted by `size_hint`, so a mismatch trips `write`'s size assertion.
    pub oids_len: u16,
    /// Parameter type OIDs, each written as a `u32`.
    pub oids: I,
}

impl<I> FrontendProtocol for Parse<'_, I>
where
    I: IntoIterator<Item = Oid> + fmt::Debug,
{
    const MSGTYPE: u8 = b'P';

    fn size_hint(&self) -> u32 {
        let names = self.prepare_name.nul_string_len() + self.sql.nul_string_len();
        let count_field = size_of!(self.oids_len);
        let oid_list = size_of!(self.oids as Oid, in ..self.oids_len);
        names + count_field + oid_list
    }

    fn encode(self, mut buf: impl BufMut) {
        let Self { prepare_name, sql, oids_len, oids } = self;
        buf.put_nul_string(prepare_name);
        buf.put_nul_string(sql);
        buf.put_u16(oids_len);
        oids.into_iter().for_each(|oid| buf.put_u32(oid));
    }
}
/// Sync message (`'S'`); it carries no body.
#[derive(Debug)]
pub struct Sync;

impl FrontendProtocol for Sync {
    const MSGTYPE: u8 = b'S';

    /// The body is empty.
    fn size_hint(&self) -> u32 {
        0
    }

    /// Nothing to write: the frame is just the type byte and the length.
    fn encode(self, _out: impl BufMut) {}
}
/// Flush message (`'H'`); it carries no body.
#[derive(Debug)]
pub struct Flush;

impl FrontendProtocol for Flush {
    const MSGTYPE: u8 = b'H';

    /// The body is empty.
    fn size_hint(&self) -> u32 {
        0
    }

    /// Nothing to write: the frame is just the type byte and the length.
    fn encode(self, _out: impl BufMut) {}
}
/// Bind message (`'B'`).
///
/// The body is, in order:
/// * the portal name;
/// * the statement name;
/// * a `u16` count followed by `u16` parameter format codes;
/// * a `u16` count followed by the parameters, each an `i32` size and then its bytes;
/// * a `u16` count followed by `u16` result format codes.
///
/// The `*_len` counts and `params_size_hint` are trusted by `size_hint`. They
/// must match what the iterators actually yield, or `write`'s size assertion
/// panics.
pub struct Bind<'a, ParamFmts, Params, ResultFmts> {
    /// Portal name written first in the body.
    pub portal_name: &'a str,
    /// Prepared-statement name written second in the body.
    pub stmt_name: &'a str,
    /// Number of codes `param_formats` yields.
    pub param_formats_len: u16,
    /// Parameter format codes.
    pub param_formats: ParamFmts,
    /// Number of parameters `params` yields.
    pub params_len: u16,
    /// Total encoded size of all parameters, including each one's 4-byte size prefix.
    pub params_size_hint: u32,
    /// Parameter values.
    pub params: Params,
    /// Number of codes `result_formats` yields.
    pub result_formats_len: u16,
    /// Result-column format codes.
    pub result_formats: ResultFmts,
}

impl<ParamFmts, Params, ResultFmts> FrontendProtocol for Bind<'_, ParamFmts, Params, ResultFmts>
where
    ParamFmts: IntoIterator<Item = PgFormat> + fmt::Debug,
    Params: Iterator + ExactSizeIterator + fmt::Debug,
    <Params as Iterator>::Item: BindParams,
    ResultFmts: IntoIterator<Item = PgFormat> + fmt::Debug,
{
    const MSGTYPE: u8 = b'B';

    fn size_hint(&self) -> u32 {
        let names = self.portal_name.nul_string_len() + self.stmt_name.nul_string_len();
        let param_formats = size_of!(self.param_formats_len)
            + size_of!(self.param_formats as u16, in ..self.param_formats_len);
        let params = size_of!(self.params_len) + self.params_size_hint;
        let result_formats = size_of!(self.result_formats_len)
            + size_of!(self.result_formats as u16, in ..self.result_formats_len);
        names + param_formats + params + result_formats
    }

    fn encode(self, mut buf: impl BufMut) {
        /// Writes a `u16` count followed by each format's `u16` code.
        fn put_formats(
            buf: &mut impl BufMut,
            count: u16,
            formats: impl IntoIterator<Item = PgFormat>,
        ) {
            buf.put_u16(count);
            for format in formats {
                buf.put_u16(format.format_code());
            }
        }

        let Self {
            portal_name,
            stmt_name,
            param_formats_len,
            param_formats,
            params_len,
            params,
            result_formats_len,
            result_formats,
            ..
        } = self;

        buf.put_nul_string(portal_name);
        buf.put_nul_string(stmt_name);
        put_formats(&mut buf, param_formats_len, param_formats);

        buf.put_u16(params_len);
        for param in params {
            // Size prefix first, then the raw parameter bytes.
            buf.put_i32(param.size());
            buf.put(param);
        }

        put_formats(&mut buf, result_formats_len, result_formats);
    }
}
/// Execute message (`'E'`); the body is the portal name followed by a `u32` row limit.
#[derive(Debug)]
pub struct Execute<'a> {
    /// Portal to execute.
    pub portal_name: &'a str,
    /// Row limit sent to the server (per the PostgreSQL protocol docs, 0 means no limit).
    pub max_row: u32,
}

impl FrontendProtocol for Execute<'_> {
    const MSGTYPE: u8 = b'E';

    fn size_hint(&self) -> u32 {
        size_of!(self.max_row) + self.portal_name.nul_string_len()
    }

    fn encode(self, mut buf: impl BufMut) {
        let Execute { portal_name, max_row } = self;
        buf.put_nul_string(portal_name);
        buf.put_u32(max_row);
    }
}
/// Close message (`'C'`); the body is one kind byte followed by a nul-terminated name.
#[derive(Debug)]
pub struct Close<'a> {
    /// Kind byte written before the name. The protocol uses `b'S'` for a
    /// prepared statement and `b'P'` for a portal; this is not validated here.
    pub variant: u8,
    /// Name of the statement or portal to close.
    pub name: &'a str,
}

impl FrontendProtocol for Close<'_> {
    const MSGTYPE: u8 = b'C';

    fn size_hint(&self) -> u32 {
        self.name.nul_string_len() + size_of!(self.variant)
    }

    fn encode(self, mut buf: impl BufMut) {
        let Close { variant, name } = self;
        buf.put_u8(variant);
        buf.put_nul_string(name);
    }
}
/// Describe message (`'D'`); the body is one kind byte followed by a nul-terminated name.
///
/// `Debug` is implemented by hand elsewhere in this file so it can label `kind`.
pub struct Describe<'a> {
    /// Kind byte written before the name: `b'S'` for a statement, `b'P'` for
    /// a portal. Other values are written as-is.
    pub kind: u8,
    /// Name of the statement or portal to describe.
    pub name: &'a str,
}

impl FrontendProtocol for Describe<'_> {
    const MSGTYPE: u8 = b'D';

    fn size_hint(&self) -> u32 {
        self.name.nul_string_len() + size_of!(self.kind)
    }

    fn encode(self, mut buf: impl BufMut) {
        let Describe { kind, name } = self;
        buf.put_u8(kind);
        buf.put_nul_string(name);
    }
}
/// Terminate message (`'X'`); it carries no body.
#[derive(Debug)]
pub struct Terminate;

impl FrontendProtocol for Terminate {
    const MSGTYPE: u8 = b'X';

    /// The body is empty.
    fn size_hint(&self) -> u32 {
        0
    }

    /// Nothing to write: the frame is just the type byte and the length.
    fn encode(self, _out: impl BufMut) {}
}
impl fmt::Debug for Describe<'_> {
    /// Renders `kind` as a readable label rather than a raw byte.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kind = match self.kind {
            b'S' => "Statement('S')",
            b'P' => "Portal('P')",
            _ => "unknown",
        };
        f.debug_struct("Describe")
            .field("kind", &kind)
            .field("name", &self.name)
            .finish()
    }
}
impl<I: fmt::Debug> fmt::Debug for Parse<'_, I> {
    /// Shows the statement name, SQL text and OIDs; `oids_len` is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            prepare_name,
            sql,
            oids,
            ..
        } = self;
        f.debug_struct("Parse")
            .field("prepare_name", prepare_name)
            .field("sql", sql)
            .field("oids", oids)
            .finish()
    }
}
impl<ParamFmts, Params, ResultFmts> fmt::Debug for Bind<'_, ParamFmts, Params, ResultFmts>
where
    ParamFmts: fmt::Debug,
    Params: fmt::Debug,
    ResultFmts: fmt::Debug,
{
    /// Shows names, formats and parameters; the `*_len` counts and
    /// `params_size_hint` are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            portal_name,
            stmt_name,
            param_formats,
            params,
            result_formats,
            ..
        } = self;
        f.debug_struct("Bind")
            .field("portal_name", portal_name)
            .field("stmt_name", stmt_name)
            .field("param_formats", param_formats)
            .field("params", params)
            .field("result_formats", result_formats)
            .finish()
    }
}
impl fmt::Debug for PasswordMessage<'_> {
    /// Never prints the secret; the `password` field always shows as `"<REDACTED>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "<REDACTED>";
        let mut out = f.debug_struct("PasswordMessage");
        out.field("password", &REDACTED);
        out.finish()
    }
}