use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::connection::stream::PgConnection;
use crate::error::{Error, Result};
use crate::protocol::backend::BackendMessage;
use crate::protocol::frontend;
use crate::row::{Row, RowDescription};
use crate::types::ToSql;
use bytes::BytesMut;
/// Process-wide counter from which every generated portal name is derived.
static PORTAL_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a fresh portal name of the form `_sp<N>`, where `<N>` is taken from
/// `PORTAL_COUNTER`.
///
/// `Relaxed` ordering suffices: only the uniqueness of each `fetch_add` result matters,
/// not its ordering relative to other memory operations.
fn next_portal_name() -> String {
    format!("_sp{}", PORTAL_COUNTER.fetch_add(1, Ordering::Relaxed))
}
/// Handle to a named server-side portal produced by `create_portal`.
///
/// Rows are pulled from it with `fetch_portal` and it is released with `close_portal`.
pub struct Portal {
    /// Server-side name of the portal, as generated by `next_portal_name`.
    pub(crate) name: String,
    /// Column metadata shared (via `Arc`) by every row fetched from this portal.
    pub(crate) description: Arc<RowDescription>,
    /// `true` once a fetch has determined that no rows remain; `fetch_portal` checks it to
    /// skip further round trips.
    pub(crate) exhausted: bool,
}

impl Portal {
    /// Reports whether this portal has been found to have no more rows.
    pub fn is_exhausted(&self) -> bool {
        let Self { exhausted, .. } = self;
        *exhausted
    }

    /// Returns the server-side name of this portal.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Creates a named portal for `sql` bound to `params` and returns a handle carrying its
/// row description.
///
/// Pipelines Parse (unnamed statement), Bind (into a freshly generated portal name),
/// Describe (portal) and Sync in a single send. It then expects, in order:
///
/// 1. ParseComplete.
/// 2. BindComplete.
/// 3. RowDescription or NoData. NoData yields an empty description.
///
/// Finally it drains the connection up to ReadyForQuery.
///
/// NOTE(review): PostgreSQL destroys named portals at the end of the current transaction,
/// and a Sync outside an explicit transaction block ends the implicit one. This presumably
/// relies on callers having opened a transaction first — confirm against callers.
///
/// # Errors
///
/// - Encoding a parameter with `ToSql::to_sql` fails.
/// - The server rejects the parse, bind or describe step (`Error::server`). The connection
///   is drained back to ReadyForQuery first. A failure during that drain is ignored so the
///   server's error is the one reported.
/// - The server replies with an unexpected message (`Error::protocol`).
/// - Sending to or receiving from the connection fails.
pub(crate) async fn create_portal(
    conn: &mut PgConnection,
    sql: &str,
    params: &[&(dyn ToSql + Sync)],
) -> Result<Portal> {
    let portal_name = next_portal_name();

    let oids: Vec<u32> = params.iter().map(|p| p.oid().0).collect();
    let encoded = encode_portal_params(params)?;
    // `None` entries are sent as SQL NULL.
    let param_refs: Vec<Option<&[u8]>> = encoded.iter().map(Option::as_deref).collect();

    frontend::parse(conn.write_buf(), "", sql, &oids);
    frontend::bind(conn.write_buf(), &portal_name, "", &param_refs, &[]);
    frontend::describe_portal(conn.write_buf(), &portal_name);
    frontend::sync(conn.write_buf());
    conn.send().await?;

    match conn.recv().await? {
        BackendMessage::ParseComplete => {}
        BackendMessage::ErrorResponse { fields } => {
            // Best effort: resynchronise, but report the server's error regardless.
            drain_until_ready(conn).await.ok();
            return Err(Error::server(
                fields.severity,
                fields.code,
                fields.message,
                fields.detail,
                fields.hint,
                fields.position,
            ));
        }
        other => {
            return Err(Error::protocol(format!(
                "portal: expected ParseComplete, got {other:?}"
            )));
        }
    }

    match conn.recv().await? {
        BackendMessage::BindComplete => {}
        BackendMessage::ErrorResponse { fields } => {
            drain_until_ready(conn).await.ok();
            return Err(Error::server(
                fields.severity,
                fields.code,
                fields.message,
                fields.detail,
                fields.hint,
                fields.position,
            ));
        }
        other => {
            return Err(Error::protocol(format!(
                "portal: expected BindComplete, got {other:?}"
            )));
        }
    }

    let description = match conn.recv().await? {
        BackendMessage::RowDescription { fields } => Arc::new(RowDescription::new(fields)),
        BackendMessage::NoData => Arc::new(RowDescription::new(vec![])),
        // A failed Describe must also be drained; otherwise the pending ReadyForQuery would
        // be read as the reply to the next command on this connection.
        BackendMessage::ErrorResponse { fields } => {
            drain_until_ready(conn).await.ok();
            return Err(Error::server(
                fields.severity,
                fields.code,
                fields.message,
                fields.detail,
                fields.hint,
                fields.position,
            ));
        }
        other => {
            return Err(Error::protocol(format!(
                "portal: expected RowDescription, got {other:?}"
            )));
        }
    };

    drain_until_ready(conn).await?;
    Ok(Portal {
        name: portal_name,
        description,
        exhausted: false,
    })
}

/// Encodes each bind parameter with `ToSql::to_sql`; `None` marks a parameter whose
/// `is_null()` is true.
fn encode_portal_params(params: &[&(dyn ToSql + Sync)]) -> Result<Vec<Option<BytesMut>>> {
    let mut encoded = Vec::with_capacity(params.len());
    for param in params {
        if param.is_null() {
            encoded.push(None);
        } else {
            let mut buf = BytesMut::new();
            param.to_sql(&mut buf)?;
            encoded.push(Some(buf));
        }
    }
    Ok(encoded)
}
/// Executes `portal` for up to `max_rows` rows and returns the rows received.
///
/// Sends Execute (with `max_rows` as the row limit) followed by Sync. Per the PostgreSQL
/// protocol, a limit of 0 means "no limit".
///
/// - PortalSuspended means more rows may remain.
/// - CommandComplete marks the portal as exhausted.
///
/// Once exhausted, further calls return an empty `Vec` without contacting the server.
///
/// NOTE(review): like `create_portal`, this sends Sync after Execute. The portal therefore
/// presumably survives between calls only inside an explicit transaction — confirm.
///
/// # Errors
///
/// - The server reports an error (`Error::server`). The connection is drained back to
///   ReadyForQuery first, on a best-effort basis.
/// - Sending to or receiving from the connection fails.
pub(crate) async fn fetch_portal(
    conn: &mut PgConnection,
    portal: &mut Portal,
    max_rows: i32,
) -> Result<Vec<Row>> {
    if portal.exhausted {
        return Ok(Vec::new());
    }

    frontend::execute(conn.write_buf(), &portal.name, max_rows);
    frontend::sync(conn.write_buf());
    conn.send().await?;

    let mut rows = Vec::new();
    loop {
        match conn.recv().await? {
            BackendMessage::DataRow { columns } => {
                rows.push(Row::new(columns, Arc::clone(&portal.description)));
            }
            BackendMessage::PortalSuspended => break,
            BackendMessage::CommandComplete { .. } => {
                portal.exhausted = true;
                break;
            }
            BackendMessage::ErrorResponse { fields } => {
                // Best effort: resynchronise, but report the server's error regardless.
                drain_until_ready(conn).await.ok();
                return Err(Error::server(
                    fields.severity,
                    fields.code,
                    fields.message,
                    fields.detail,
                    fields.hint,
                    fields.position,
                ));
            }
            // The server finished the Execute/Sync cycle without PortalSuspended or
            // CommandComplete (e.g. EmptyQueryResponse for an empty query). ReadyForQuery is
            // already consumed, so return now instead of blocking on a message that never
            // arrives.
            BackendMessage::ReadyForQuery { .. } => {
                portal.exhausted = true;
                return Ok(rows);
            }
            // Any other message is skipped.
            _ => {}
        }
    }

    drain_until_ready(conn).await?;
    Ok(rows)
}
pub(crate) async fn close_portal(conn: &mut PgConnection, portal: Portal) -> Result<()> {
if portal.exhausted {
return Ok(());
}
frontend::close_portal(conn.write_buf(), &portal.name);
frontend::sync(conn.write_buf());
conn.send().await?;
loop {
match conn.recv().await? {
BackendMessage::CloseComplete => break,
BackendMessage::ErrorResponse { fields } => {
drain_until_ready(conn).await.ok();
return Err(Error::server(
fields.severity,
fields.code,
fields.message,
fields.detail,
fields.hint,
fields.position,
));
}
_ => {}
}
}
drain_until_ready(conn).await?;
Ok(())
}
/// Reads and discards backend messages until a ReadyForQuery message has been consumed.
///
/// # Errors
///
/// Propagates the first error returned by `conn.recv()`.
async fn drain_until_ready(conn: &mut PgConnection) -> Result<()> {
    while !matches!(conn.recv().await?, BackendMessage::ReadyForQuery { .. }) {}
    Ok(())
}