use std::fmt::Debug;
use std::ops::DerefMut;
use std::sync::Arc;
use bytes::Bytes;
use futures::stream::StreamExt;
use postgres_types::FromSqlOwned;
use tokio::sync::Mutex;
use crate::api::Type;
use crate::api::results::QueryResponse;
use crate::error::{PgWireError, PgWireResult};
use crate::messages::data::FORMAT_CODE_BINARY;
use crate::messages::extendedquery::Bind;
use crate::types::FromSqlText;
use crate::types::format::FormatOptions;
use super::DEFAULT_NAME;
use super::results::FieldFormat;
use super::stmt::StoredStatement;
/// A bound statement together with its parameters, formats and execution state.
///
/// Instances are created with `Portal::try_new` (from a `Bind` message) or
/// `Portal::new_cursor`.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct Portal<S> {
/// Name of the portal; `try_new` substitutes `DEFAULT_NAME` when the bind
/// message carries no name.
pub name: String,
/// The stored statement this portal executes, shared via `Arc`.
pub statement: Arc<StoredStatement<S>>,
/// Format (text/binary) used to decode the entries of `parameters`.
pub parameter_format: Format,
/// Raw parameter values. A `None` entry makes `Portal::parameter` return
/// `Ok(None)` — presumably this represents SQL NULL; confirm against `Bind`.
pub parameters: Vec<Option<Bytes>>,
/// Format (text/binary) requested for the result columns.
pub result_column_format: Format,
/// Execution state, shared behind an async mutex so that `start`/`fetch`
/// can update it through `&self`.
pub state: Arc<Mutex<PortalExecutionState>>,
}
/// Execution progress of a portal, as tracked by `Portal::start` and
/// `Portal::fetch`.
#[derive(Default, Debug)]
pub enum PortalExecutionState {
/// No response has been attached yet; `Portal::fetch` fails with
/// `PgWireError::PortalNotStarted` in this state.
#[default]
Initial,
/// A response has been attached by `Portal::start`; `Portal::fetch` pulls
/// rows from its stream.
Suspended(QueryResponse),
/// The row stream has been fully consumed; `Portal::fetch` returns an empty
/// response in this state.
Finished,
}
/// Outcome of a single `Portal::fetch` call.
#[derive(Debug)]
pub struct FetchResult {
/// Response carrying the rows collected by this fetch.
pub response: QueryResponse,
/// `true` when the fetch stopped because the row limit was reached,
/// `false` when the underlying row stream ended.
pub suspended: bool,
}
/// Text/binary format selection for a list of parameters or result columns.
#[derive(Debug, Clone, Default)]
pub enum Format {
/// Every item uses the text format (the default).
#[default]
UnifiedText,
/// Every item uses the binary format.
UnifiedBinary,
/// Per-item format codes, looked up by item index.
Individual(Vec<i16>),
}
impl From<i16> for Format {
    /// Maps a single format code onto a unified format: `FORMAT_CODE_BINARY`
    /// selects [`Format::UnifiedBinary`], every other code selects
    /// [`Format::UnifiedText`].
    fn from(v: i16) -> Format {
        match v {
            code if code == FORMAT_CODE_BINARY => Format::UnifiedBinary,
            _ => Format::UnifiedText,
        }
    }
}
impl Format {
    /// Resolves the format of the item (parameter or column) at `idx`.
    ///
    /// The unified variants ignore `idx`. For [`Format::Individual`] the code
    /// stored at `idx` is converted with `FieldFormat::from`.
    ///
    /// The per-item code list originates from the client, so it may be
    /// shorter than the number of items. Instead of panicking on an
    /// out-of-range index, this falls back to `FieldFormat::Text`, the
    /// protocol's default format.
    pub fn format_for(&self, idx: usize) -> FieldFormat {
        match self {
            Format::UnifiedText => FieldFormat::Text,
            Format::UnifiedBinary => FieldFormat::Binary,
            Format::Individual(fv) => fv
                .get(idx)
                .copied()
                .map_or(FieldFormat::Text, FieldFormat::from),
        }
    }

    /// Returns `true` when the item at `idx` uses the text format.
    pub fn is_text(&self, idx: usize) -> bool {
        self.format_for(idx) == FieldFormat::Text
    }

    /// Returns `true` when the item at `idx` uses the binary format.
    pub fn is_binary(&self, idx: usize) -> bool {
        self.format_for(idx) == FieldFormat::Binary
    }

    /// Builds a `Format` from a list of wire format codes.
    ///
    /// * no codes: everything is text,
    /// * exactly one code: that code applies to every item,
    /// * otherwise: one code per item.
    fn from_codes(codes: &[i16]) -> Self {
        match codes {
            [] => Format::UnifiedText,
            [single] => Format::from(*single),
            many => Format::Individual(many.to_vec()),
        }
    }
}
impl<S: Clone> Portal<S> {
    /// Builds a portal from a `Bind` message and the statement it refers to.
    ///
    /// The portal name falls back to `DEFAULT_NAME` when the message has
    /// none. Parameter and result formats are derived from the message's
    /// format code lists, and the parameter values are cloned from the
    /// message. The new portal starts in [`PortalExecutionState::Initial`].
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is part of the signature.
    pub fn try_new(bind: &Bind, statement: Arc<StoredStatement<S>>) -> PgWireResult<Self> {
        let portal_name = bind
            .portal_name
            .clone()
            .unwrap_or_else(|| DEFAULT_NAME.to_owned());

        let param_format = Format::from_codes(&bind.parameter_format_codes);
        let result_format = Format::from_codes(&bind.result_column_format_codes);

        Ok(Portal {
            name: portal_name,
            statement,
            parameter_format: param_format,
            parameters: bind.parameters.clone(),
            result_column_format: result_format,
            state: Arc::new(Mutex::new(PortalExecutionState::Initial)),
        })
    }

    /// Creates a parameterless portal with the given name, using the text
    /// format for both parameters and results, in the
    /// [`PortalExecutionState::Initial`] state.
    pub fn new_cursor(name: String, statement: Arc<StoredStatement<S>>) -> Self {
        Portal {
            name,
            statement,
            parameter_format: Format::UnifiedText,
            parameters: vec![],
            result_column_format: Format::UnifiedText,
            state: Arc::new(Mutex::new(PortalExecutionState::Initial)),
        }
    }

    /// Number of parameter values bound to this portal.
    pub fn parameter_len(&self) -> usize {
        self.parameters.len()
    }

    /// Decodes the parameter at `idx` as `T`, interpreting it as `pg_type`.
    ///
    /// The value is decoded with `FromSql::from_sql` when the parameter's
    /// format is binary and with `FromSqlText::from_sql_text` (default
    /// `FormatOptions`) otherwise. A parameter stored as `None` yields
    /// `Ok(None)`.
    ///
    /// # Errors
    ///
    /// * `PgWireError::InvalidRustTypeForParameter` when `T` does not accept
    ///   `pg_type`,
    /// * `PgWireError::ParameterIndexOutOfBound` when `idx` is not a valid
    ///   parameter index,
    /// * `PgWireError::FailedToParseParameter` when decoding fails.
    pub fn parameter<'a, T>(&'a self, idx: usize, pg_type: &Type) -> PgWireResult<Option<T>>
    where
        T: FromSqlOwned + FromSqlText<'a>,
    {
        if !T::accepts(pg_type) {
            return Err(PgWireError::InvalidRustTypeForParameter(
                pg_type.name().to_owned(),
            ));
        }

        let param = self
            .parameters
            .get(idx)
            .ok_or(PgWireError::ParameterIndexOutOfBound(idx))?;

        match param {
            Some(raw) => {
                let decoded = if self.parameter_format.is_binary(idx) {
                    T::from_sql(pg_type, raw)
                } else {
                    T::from_sql_text(pg_type, raw, &FormatOptions::default())
                };
                decoded
                    .map(Some)
                    .map_err(PgWireError::FailedToParseParameter)
            }
            None => Ok(None),
        }
    }

    /// Returns a new handle to the shared execution state.
    pub fn state(&self) -> Arc<Mutex<PortalExecutionState>> {
        self.state.clone()
    }

    /// Attaches `response` to the portal and moves it to
    /// [`PortalExecutionState::Suspended`], replacing any previous state.
    pub async fn start(&self, response: QueryResponse) {
        let mut state = self.state.lock().await;
        *state = PortalExecutionState::Suspended(response);
    }

    /// Pulls up to `max_rows` rows from the attached response.
    ///
    /// `max_rows == 0` means "no limit": rows are read until the stream ends.
    /// The returned [`FetchResult::suspended`] is `true` when the row limit
    /// stopped the fetch and `false` when the stream ended, in which case the
    /// portal moves to [`PortalExecutionState::Finished`]. Fetching from a
    /// finished portal returns an empty response.
    ///
    /// # Errors
    ///
    /// * `PgWireError::PortalNotStarted` when no response has been attached,
    /// * any error yielded by the row stream; rows read before the error are
    ///   dropped and the portal stays suspended.
    pub async fn fetch(&self, max_rows: usize) -> PgWireResult<FetchResult> {
        // Initial buffer size when the caller does not limit the row count.
        const UNLIMITED_INITIAL_CAPACITY: usize = 256;
        // Upper bound for the up-front allocation. `max_rows` is supplied by
        // the caller (ultimately the client), so it must not be trusted as an
        // allocation size; the vector still grows on demand past this bound.
        const MAX_PREALLOCATED_ROWS: usize = 1024;

        let mut state = self.state.lock().await;
        match state.deref_mut() {
            PortalExecutionState::Initial => Err(PgWireError::PortalNotStarted),
            PortalExecutionState::Finished => Ok(FetchResult {
                response: QueryResponse::new(Arc::new(vec![]), futures::stream::empty()),
                suspended: false,
            }),
            PortalExecutionState::Suspended(response) => {
                let command_tag = response.command_tag().to_owned();
                let row_schema = response.row_schema();
                let data_rows = response.data_rows();

                let capacity = if max_rows == 0 {
                    UNLIMITED_INITIAL_CAPACITY
                } else {
                    max_rows.min(MAX_PREALLOCATED_ROWS)
                };
                let mut rows = Vec::with_capacity(capacity);

                let mut suspended = true;
                while max_rows == 0 || rows.len() < max_rows {
                    match data_rows.next().await {
                        Some(row) => rows.push(row?),
                        None => {
                            // Stream exhausted: nothing left to resume.
                            suspended = false;
                            break;
                        }
                    }
                }

                if !suspended {
                    *state = PortalExecutionState::Finished;
                }

                let result_response = QueryResponse {
                    command_tag,
                    row_schema,
                    data_rows: Box::pin(futures::stream::iter(rows.into_iter().map(Ok))),
                };
                Ok(FetchResult {
                    response: result_response,
                    suspended,
                })
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use postgres_types::FromSql;

    use super::*;

    /// Decoding raw bytes as `String` with the `UNKNOWN` type yields the
    /// original text.
    #[test]
    fn test_from_sql() {
        let raw = "helloworld".as_bytes();
        let decoded = String::from_sql(&Type::UNKNOWN, raw).unwrap();
        assert_eq!("helloworld", decoded)
    }
}