use std::fmt::Debug;
use std::sync::Arc;
use bytes::Bytes;
use postgres_types::FromSqlOwned;
use tokio::sync::Mutex;
use crate::api::Type;
use crate::api::results::QueryResponse;
use crate::error::{PgWireError, PgWireResult};
use crate::messages::data::FORMAT_CODE_BINARY;
use crate::messages::extendedquery::Bind;
use crate::types::FromSqlText;
use crate::types::format::FormatOptions;
use super::DEFAULT_NAME;
use super::results::FieldFormat;
use super::stmt::StoredStatement;
/// A bound portal: a stored statement combined with the parameter values and
/// format codes taken from a client `Bind` message.
///
/// [`Portal::try_new`] constructs one from a `Bind` message.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct Portal<S> {
/// Portal name; when the `Bind` message carries no name, `DEFAULT_NAME` is used.
pub name: String,
/// The stored statement this portal was bound against.
pub statement: Arc<StoredStatement<S>>,
/// Text/binary format of the entries in `parameters`, resolved per index.
pub parameter_format: Format,
/// Raw parameter values copied from the `Bind` message. `None` entries are
/// reported as `Ok(None)` by [`Portal::parameter`] (presumably SQL NULL).
pub parameters: Vec<Option<Bytes>>,
/// Text/binary format requested for the result columns, resolved per index.
pub result_column_format: Format,
/// Execution progress behind a tokio async mutex; [`Portal::state`] hands out
/// clones of this `Arc`, so every handle sees the same state.
pub state: Arc<Mutex<PortalExecutionState>>,
}
/// Execution progress of a [`Portal`].
#[derive(Default, Debug)]
pub enum PortalExecutionState {
/// Not yet executed. This is the `Default` and the state [`Portal::try_new`] starts in.
#[default]
Initial,
/// Execution paused with a response retained for later continuation.
/// NOTE(review): presumably a partially consumed result from a row-limited
/// `Execute`; confirm against the executor that sets this state.
Suspended(QueryResponse),
/// Execution is over. NOTE(review): nothing in this file transitions into
/// this state; confirm where callers set it.
Finished,
}
/// Wire format codes for a group of values (query parameters or result columns).
#[derive(Debug, Clone, Default)]
pub enum Format {
/// Every value uses text format. This is the `Default` and what an empty code list maps to.
#[default]
UnifiedText,
/// Every value uses binary format (a single binary code was supplied).
UnifiedBinary,
/// One raw format code per value, looked up by position.
Individual(Vec<i16>),
}
impl From<i16> for Format {
    /// Builds a unified [`Format`] from a single wire format code.
    ///
    /// `FORMAT_CODE_BINARY` selects [`Format::UnifiedBinary`]. Every other code,
    /// including unrecognized ones, selects [`Format::UnifiedText`].
    fn from(code: i16) -> Format {
        match code {
            c if c == FORMAT_CODE_BINARY => Format::UnifiedBinary,
            _ => Format::UnifiedText,
        }
    }
}
impl Format {
    /// Returns the wire format that applies to the value at position `idx`.
    ///
    /// The unified variants return the same format for every index. For
    /// [`Format::Individual`], the code stored at `idx` is converted with
    /// `FieldFormat::from`. An index past the end of the code list falls back
    /// to [`FieldFormat::Text`] instead of panicking.
    pub fn format_for(&self, idx: usize) -> FieldFormat {
        match self {
            Format::UnifiedText => FieldFormat::Text,
            Format::UnifiedBinary => FieldFormat::Binary,
            // The codes come straight from the client's Bind message and their
            // count is never checked against the number of values, so a
            // malformed message must not be able to trigger an index panic.
            Format::Individual(codes) => codes
                .get(idx)
                .map_or(FieldFormat::Text, |&code| FieldFormat::from(code)),
        }
    }

    /// Returns `true` if the value at `idx` uses text format.
    pub fn is_text(&self, idx: usize) -> bool {
        self.format_for(idx) == FieldFormat::Text
    }

    /// Returns `true` if the value at `idx` uses binary format.
    pub fn is_binary(&self, idx: usize) -> bool {
        self.format_for(idx) == FieldFormat::Binary
    }

    /// Collapses a list of format codes from a `Bind` message into a [`Format`].
    ///
    /// Three cases are handled:
    /// - No codes means every value is text.
    /// - A single code applies to every value.
    /// - Two or more codes are kept verbatim and looked up by position.
    fn from_codes(codes: &[i16]) -> Self {
        match codes {
            [] => Format::UnifiedText,
            [only] => Format::from(*only),
            _ => Format::Individual(codes.to_vec()),
        }
    }
}
impl<S: Clone> Portal<S> {
pub fn try_new(bind: &Bind, statement: Arc<StoredStatement<S>>) -> PgWireResult<Self> {
let portal_name = bind
.portal_name
.clone()
.unwrap_or_else(|| DEFAULT_NAME.to_owned());
let param_format = Format::from_codes(&bind.parameter_format_codes);
let result_format = Format::from_codes(&bind.result_column_format_codes);
Ok(Portal {
name: portal_name,
statement,
parameter_format: param_format,
parameters: bind.parameters.clone(),
result_column_format: result_format,
state: Arc::new(Mutex::new(PortalExecutionState::Initial)),
})
}
pub fn parameter_len(&self) -> usize {
self.parameters.len()
}
pub fn parameter<'a, T>(&'a self, idx: usize, pg_type: &Type) -> PgWireResult<Option<T>>
where
T: FromSqlOwned + FromSqlText<'a>,
{
if !T::accepts(pg_type) {
return Err(PgWireError::InvalidRustTypeForParameter(
pg_type.name().to_owned(),
));
}
let param = self
.parameters
.get(idx)
.ok_or_else(|| PgWireError::ParameterIndexOutOfBound(idx))?;
let _format = self.parameter_format.format_for(idx);
if let Some(param) = param {
if self.parameter_format.is_binary(idx) {
T::from_sql(pg_type, param)
.map(|v| Some(v))
.map_err(PgWireError::FailedToParseParameter)
} else {
T::from_sql_text(pg_type, param, &FormatOptions::default())
.map(|v| Some(v))
.map_err(PgWireError::FailedToParseParameter)
}
} else {
Ok(None)
}
}
pub fn state(&self) -> Arc<Mutex<PortalExecutionState>> {
self.state.clone()
}
}
#[cfg(test)]
mod tests {
    use postgres_types::FromSql;

    use super::*;

    /// Bytes tagged as `UNKNOWN` must decode through `String::from_sql` back
    /// into the original text.
    #[test]
    fn test_from_sql() {
        let input = "helloworld";
        let decoded = String::from_sql(&Type::UNKNOWN, input.as_bytes()).unwrap();
        assert_eq!(input, decoded);
    }
}