use bytes::{Buf, BytesMut};
use lru::LruCache;
use std::{
future::Ready,
io,
num::NonZeroUsize,
task::{Context, Poll, ready},
time::Instant,
};
use crate::{
Result,
common::{span, verbose},
executor::Executor,
net::Socket,
phase,
postgres::{
BackendProtocol, ErrorResponse, FrontendProtocol, NoticeResponse, backend, frontend,
},
statement::StatementName,
transport::{PgTransport, PgTransportExt},
};
mod config;
pub use config::{Config, ParseError};
/// Initial capacity, in bytes, given to both the read and the write buffer of
/// a new connection.
const DEFAULT_BUF_CAPACITY: usize = 1 << 10;
/// Maximum number of prepared statement names kept in a connection's LRU cache.
const DEFAULT_PREPARED_STMT_CACHE: NonZeroUsize = match NonZeroUsize::new(24) {
    Some(capacity) => capacity,
    None => unreachable!(),
};
/// A single PostgreSQL client connection: the socket plus the buffers and
/// bookkeeping used to exchange frontend/backend protocol messages over it.
#[derive(Debug)]
pub struct Connection {
/// Underlying transport: TCP, or on Unix a Unix-domain socket when the host is `localhost`.
socket: Socket,
/// Bytes received from the backend that have not yet been split into messages.
read_buf: BytesMut,
/// Encoded frontend messages waiting to be flushed to the socket.
write_buf: BytesMut,
/// LRU cache of prepared statement names keyed by a `u64` SQL id
/// (presumably a hash of the query text — TODO confirm against callers).
stmts: LruCache<u64, StatementName>,
/// Instant recorded when the connection object was created, before the startup phase ran.
connected_at: Instant,
/// Number of `ReadyForQuery` replies still expected; incremented by
/// `ready_request` and drained (one per `ReadyForQuery`) by `poll_ready`.
sync_pending: usize,
/// Backend key data from the startup phase; all zeros until startup completes.
backend_key: backend::BackendKeyData,
}
impl Connection {
    /// Connects using the configuration returned by [`Config::from_env`].
    ///
    /// `Config::from_env` is evaluated immediately, when this function is
    /// called, not when the returned future is first polled.
    pub fn connect_env() -> impl Future<Output = Result<Connection>> {
        let config = Config::from_env();
        Self::connect_with(config)
    }

    /// Parses `url` with [`Config::parse`] and connects with the result.
    ///
    /// # Errors
    /// Fails if `url` cannot be parsed, or if [`Connection::connect_with`] fails.
    pub async fn connect(url: &str) -> Result<Self> {
        let config = Config::parse(url)?;
        Self::connect_with(config).await
    }

    /// Opens a socket to the server described by `config` and runs the
    /// startup phase on it.
    ///
    /// On Unix, when `config.host` is exactly `"localhost"`, the Unix-domain
    /// socket `/run/postgresql/.s.PGSQL.<port>` is tried first. Any error from
    /// that attempt is discarded and a TCP connection is made instead.
    ///
    /// # Errors
    /// Fails if the TCP connection cannot be established, or if the startup
    /// phase fails.
    pub async fn connect_with(config: Config) -> Result<Self> {
        let via_unix = if cfg!(unix) && config.host == "localhost" {
            let path = format!("/run/postgresql/.s.PGSQL.{}", config.port);
            Socket::connect_socket(&path).await.ok()
        } else {
            None
        };
        let socket = match via_unix {
            Some(socket) => socket,
            None => Socket::connect_tcp(&config.host, config.port).await?,
        };

        let mut conn = Self {
            socket,
            read_buf: BytesMut::with_capacity(DEFAULT_BUF_CAPACITY),
            write_buf: BytesMut::with_capacity(DEFAULT_BUF_CAPACITY),
            stmts: LruCache::new(DEFAULT_PREPARED_STMT_CACHE),
            connected_at: Instant::now(),
            // Placeholder; replaced by the key reported during startup below.
            backend_key: backend::BackendKeyData { process_id: 0, secret_key: 0 },
            sync_pending: 0,
        };
        let startup = phase::startup(&config, &mut conn).await?;
        conn.backend_key = startup.backend_key_data;
        Ok(conn)
    }
}
impl Connection {
    /// Returns the [`Instant`] recorded when this connection was created.
    pub fn connected_at(&self) -> Instant {
        let Self { connected_at, .. } = self;
        *connected_at
    }

    /// Returns a copy of the backend key data stored on this connection.
    pub fn backend_key(&self) -> backend::BackendKeyData {
        let Self { backend_key, .. } = self;
        *backend_key
    }
}
impl Connection {
    /// Polls shutdown of the underlying socket directly.
    ///
    /// Unlike [`Connection::close`], this queues no `Terminate` message and
    /// flushes nothing.
    pub fn poll_shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let socket = &mut self.socket;
        socket.poll_shutdown(cx)
    }

    /// Consumes the connection.
    ///
    /// This queues a `Terminate` message, flushes all buffered output, and
    /// then shuts the socket down.
    ///
    /// # Errors
    /// Returns the flush error, in which case no explicit socket shutdown is
    /// attempted. Otherwise returns the result of the shutdown.
    pub async fn close(mut self) -> io::Result<()> {
        self.send(frontend::Terminate);
        self.flush().await?;
        let socket = &mut self.socket;
        socket.shutdown().await
    }
}
/// Parses one complete backend message out of `$io.read_buf`, reading more
/// bytes from `$io.socket` when the buffer does not yet hold a whole frame.
///
/// Must be expanded inside a `loop`/`while` of a function that returns
/// `Poll<Result<_>>`. When more data is needed, it either returns
/// `Poll::Pending` (via `ready!`) or `continue`s the enclosing loop so the
/// buffer is re-parsed.
///
/// On success it binds:
/// - `$msgtype`: the one-byte message type.
/// - `$body`: the frozen payload that follows the 5-byte header.
///
/// A frame is a type byte followed by an `i32` length. The length counts its
/// own 4 bytes plus the payload, but not the type byte, so the whole frame
/// spans `1 + len` bytes.
///
/// A length below 4 (including negative values) makes the enclosing function
/// return an `InvalidData` I/O error.
macro_rules! poll_message {
    (
        poll($io:ident, $cx:ident);
        let $msgtype:ident;
        let $body:ident;
    ) => {
        // The 5-byte header (type + length) is needed before anything can be parsed.
        let Some(mut header) = $io.read_buf.get(..5) else {
            $io.read_buf.reserve(1024);
            // NOTE(review): if `crate::io::poll_read` reports EOF as `Ready(Ok(0))`
            // rather than as an error, this `continue` would spin on a closed
            // socket — confirm its EOF behaviour.
            ready!(crate::io::poll_read(&mut $io.socket, &mut $io.read_buf, $cx)?);
            continue;
        };
        let $msgtype = header.get_u8();
        // Validate the length before using it. A negative value would otherwise
        // turn into a huge `usize` and overflow `reserve`, and a value below 4
        // would underflow `len - 4` below.
        let len: usize = match usize::try_from(header.get_i32()) {
            Ok(len) if len >= 4 => len,
            _ => {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "invalid backend message length",
                )
                .into()));
            }
        };
        // Wait until the whole `1 + len`-byte frame is buffered.
        if $io.read_buf.len() - 1 < len {
            $io.read_buf.reserve(1 + len);
            ready!(crate::io::poll_read(&mut $io.socket, &mut $io.read_buf, $cx)?);
            continue;
        }
        $io.read_buf.advance(5);
        let $body = $io.read_buf.split_to(len - 4).freeze();
        verbose!("(B){:?}",backend::BackendMessage::decode($msgtype, $body.clone()).unwrap());
    };
}
impl Connection {
    /// Returns a future that resolves once the connection is idle.
    ///
    /// Idle means the write buffer has been flushed and `sync_pending` has
    /// dropped to zero, i.e. one `ReadyForQuery` has been received for every
    /// request registered with `ready_request`.
    pub fn ready(&mut self) -> impl Future<Output = Result<()>> {
        std::future::poll_fn(|cx|self.poll_ready(cx))
    }

    /// Polls the connection toward the idle state described on [`Connection::ready`].
    ///
    /// While `sync_pending` is non-zero, incoming messages are consumed and
    /// discarded:
    /// - Notices are logged when the `log` feature is enabled.
    /// - An `ErrorResponse` is logged and answered with another `Sync`, which
    ///   is counted as one more pending `ReadyForQuery`.
    /// - Each `ReadyForQuery` decrements the counter.
    ///
    /// # Errors
    /// Propagates I/O errors from flushing or reading, and the malformed-frame
    /// error produced by `poll_message!`.
    pub(crate) fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<()>> {
        loop {
            // Flush on every iteration, not only on entry. The ErrorResponse
            // arm below queues a new Sync, and the ReadyForQuery it is counted
            // for will never arrive unless that Sync actually reaches the
            // server. Reading without flushing here would wait forever.
            if !self.write_buf.is_empty() {
                ready!(self.poll_flush(cx)?);
            }
            if self.sync_pending == 0 {
                return Poll::Ready(Ok(()));
            }
            verbose!(self.sync_pending,"healthcheck");
            poll_message! {
                poll(self, cx);
                let msgtype;
                let _body;
            }
            match msgtype {
                ErrorResponse::MSGTYPE => {
                    // Queue a Sync and expect one more ReadyForQuery for it.
                    // The flush at the top of the loop sends it.
                    self.send(frontend::Sync);
                    self.ready_request();
                    #[cfg(feature = "log")]
                    log::error!("{}",ErrorResponse::new(_body));
                }
                NoticeResponse::MSGTYPE => {
                    #[cfg(feature = "log")]
                    log::warn!("{}",NoticeResponse::new(_body));
                }
                backend::ParameterStatus::MSGTYPE => {}
                backend::ReadyForQuery::MSGTYPE => {
                    self.sync_pending -= 1;
                }
                // Any other message belongs to a request being drained; drop it.
                _ => {}
            }
        }
    }
}
impl PgTransport for Connection {
    /// Drives `crate::io::poll_write_all` to push `write_buf` out through the socket.
    fn poll_flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let Self { socket, write_buf, .. } = self;
        crate::io::poll_write_all(socket, write_buf, cx)
    }

    /// Receives and decodes the next backend message as `B`.
    ///
    /// Outstanding Sync round-trips are settled first via `poll_ready`, which
    /// discards their messages. After that:
    /// - Notices are logged (with the `log` feature) and skipped.
    /// - `ParameterStatus` messages are skipped.
    /// - An `ErrorResponse` queues a `Sync`, registers it with `ready_request`,
    ///   and is returned as an error.
    ///
    /// # Errors
    /// Returns the server's `ErrorResponse`, I/O errors, or a decode error from `B::decode`.
    fn poll_recv<B: BackendProtocol>(&mut self, cx: &mut Context) -> Poll<Result<B>> {
        ready!(self.poll_ready(cx)?);
        loop {
            poll_message! {
                poll(self, cx);
                let msgtype;
                let body;
            }
            match msgtype {
                ErrorResponse::MSGTYPE => {
                    self.send(frontend::Sync);
                    self.ready_request();
                    return Poll::Ready(Err(ErrorResponse::new(body).into()));
                }
                NoticeResponse::MSGTYPE => {
                    #[cfg(feature = "log")]
                    log::warn!("{}",NoticeResponse::new(body));
                }
                backend::ParameterStatus::MSGTYPE => {}
                _ => {
                    let message = B::decode(msgtype, body)?;
                    return Poll::Ready(Ok(message));
                }
            }
        }
    }

    /// Records that one more `ReadyForQuery` is expected. `poll_ready` will
    /// consume messages until it has seen it.
    fn ready_request(&mut self) {
        self.sync_pending += 1;
    }

    /// Encodes `message` into `write_buf`. Nothing reaches the socket until
    /// the buffer is flushed.
    fn send<F: FrontendProtocol>(&mut self, message: F) {
        verbose!(?message,"(F)");
        let buf = &mut self.write_buf;
        frontend::write(message, buf);
    }

    /// Encodes the startup packet into `write_buf`.
    fn send_startup(&mut self, startup: frontend::Startup) {
        verbose!(?startup,"(F)");
        let buf = &mut self.write_buf;
        startup.write(buf);
    }

    /// Looks up the cached statement name for `sqlid`. A hit also marks the
    /// entry as most recently used in the LRU cache.
    fn get_stmt(&mut self, sqlid: u64) -> Option<StatementName> {
        let cached = self.stmts.get(&sqlid).cloned();
        if let Some(_name) = &cached {
            span!("statement");
            verbose!(name=%_name,"cache hit");
        }
        cached
    }

    /// Caches `name` under `id`.
    ///
    /// If the cache hands back a displaced entry (replaced or evicted), its
    /// statement is closed on the server by queueing `Close` + `Sync`, and
    /// that Sync is registered with `ready_request`.
    fn add_stmt(&mut self, id: u64, name: StatementName) {
        span!("statement");
        verbose!(%name,"added");
        let Some((_id, name)) = self.stmts.push(id, name) else {
            return;
        };
        verbose!(%name,"removed");
        self.send(frontend::Close {
            variant: b'S',
            name: name.as_str(),
        });
        self.send(frontend::Sync);
        self.ready_request();
    }
}
impl Executor for Connection {
    type Transport = Self;
    type Future = Ready<Result<Self::Transport>>;

    /// A `Connection` is its own transport, so this resolves immediately with `self`.
    fn connection(self) -> Self::Future {
        let transport: Result<Self::Transport> = Ok(self);
        std::future::ready(transport)
    }
}