use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, Ordering};
use super::client_options::ClientOptions;
use super::params::BinaryParam;
use crate::PgToPlError;
use crate::models::column_result::{
ColumnStorage, clone_storages, column_from_field, column_to_series, push_column_value,
};
use crate::models::params::format_params;
use crate::utils::error::{MessageX, PgToPlResult};
use crate::utils::{md5_hash, statement_name};
use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use polars::prelude::*;
use postgres_protocol::IsNull;
use postgres_protocol::authentication::sasl::{ChannelBinding, SCRAM_SHA_256, ScramSha256};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::{sasl_initial_response, sasl_response};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, error, warn};
/// Per-connection cache entry for a statement prepared by `Client::_query`.
#[derive(Debug, Clone)]
struct PreparedStatementInfo {
/// Parameter type OIDs the statement was prepared with (sent in the Parse message).
/// `_query` rejects a later call whose formatted parameter types differ with
/// `ParamTypeMismatch`.
param_types: Vec<u32>,
/// Column storages built from the statement's RowDescription. `_query` stores them via
/// `clone_storages` and clones them for every execution of the cached statement.
columns: Vec<ColumnStorage>,
}
/// PostgreSQL client that runs queries over the extended query protocol and returns the
/// results as polars `DataFrame`s.
///
/// All traffic goes through `stream`, which sits behind a tokio mutex, so only one
/// request/response exchange is in flight at a time.
pub struct Client {
/// `false` until the startup handshake succeeds. Set back to `false` by
/// `mark_unhealthy` (on server, protocol or I/O errors) and by `reset`. `query` and
/// `ping` refuse to run while it is `false`.
healthy: AtomicBool,
/// `false` until the handshake completes. `query` clears it before calling `_query` and
/// restores it only if the client is still healthy. A request abandoned mid-exchange
/// therefore leaves it `false`, and the next `query` calls `reset` first.
ready: AtomicBool,
/// Connection settings: address (`connect_url`), user, database, password, test flags.
options: ClientOptions,
/// The server socket, locked for the duration of each exchange.
stream: Mutex<TcpStream>,
/// Statements already prepared on this connection, keyed by `statement_name(query)`.
/// Cleared by `reset`.
prepared_statements: Mutex<HashMap<String, PreparedStatementInfo>>,
/// Counter used by `_query` to name portals `portal_<n>`; reset to 0 by `reset`.
portal_count: Mutex<i32>,
/// Copied from `ClientOptions` or set by `with_monkey_chaos_already_prepare`.
/// NOTE(review): `_query` guards its use with `&& false`, so it currently has no effect.
monkey_chaos_already_prepare: bool,
}
impl Client {
pub async fn new(options: ClientOptions) -> PgToPlResult<Self> {
let stream = TcpStream::connect(options.connect_url()).await?;
Ok(Client {
monkey_chaos_already_prepare: options.monkey_chaos_already_prepare,
healthy: AtomicBool::new(false),
ready: AtomicBool::new(false),
options,
stream: Mutex::new(stream),
prepared_statements: Mutex::new(HashMap::new()),
portal_count: Mutex::new(0),
})
}
/// Builder-style switch that turns on `monkey_chaos_already_prepare` and returns the
/// client.
pub fn with_monkey_chaos_already_prepare(self) -> Self {
    let mut client = self;
    client.monkey_chaos_already_prepare = true;
    client
}
pub async fn connect(&self) -> PgToPlResult<()> {
let mut stream = self.stream.lock().await;
self._connect(&mut stream).await
}
pub async fn _connect(&self, stream: &mut TcpStream) -> PgToPlResult<()> {
let mut buf = BytesMut::new(); frontend::startup_message(
[
("user", self.options.user.as_str()),
("database", self.options.database.as_str()),
],
&mut buf,
)?;
stream.write_all(&buf).await?;
let mut read_buffer = BytesMut::with_capacity(8192);
let mut error_to_return: Option<PgToPlError> = None;
let mut done = false;
let mut scram_state: Option<ScramSha256> = None;
while !done {
read_buffer.reserve(8192);
let n = {
read_buffer.reserve(8192);
let dst = read_buffer.chunk_mut();
let buf: &mut [u8] =
unsafe { std::slice::from_raw_parts_mut(dst.as_mut_ptr(), dst.len()) };
let n = stream.read(buf).await?;
unsafe {
read_buffer.advance_mut(n);
}
n
};
if n == 0 {
self.mark_unhealthy();
return Err(PgToPlError::ConnectionClosed);
}
loop {
let message = backend::Message::parse(&mut read_buffer);
match message {
Ok(Some(backend::Message::ReadyForQuery(_))) => {
done = true;
if error_to_return.is_some() {
self.mark_unhealthy();
}
break;
}
Ok(Some(backend::Message::ErrorResponse(error))) => {
self.mark_unhealthy();
if error_to_return.is_none() {
error_to_return = Some(error.into());
}
}
Ok(Some(backend::Message::AuthenticationCleartextPassword)) => {
warn!("Authentication: Cleartext password requested");
}
Ok(Some(backend::Message::AuthenticationSasl(body))) => {
let mechanisms: Vec<&str> = body.mechanisms().collect()?;
if !mechanisms.contains(&SCRAM_SHA_256) {
self.mark_unhealthy();
error_to_return = Some(PgToPlError::UnsupportedSaslMechanism {
offered: mechanisms.iter().map(|s| s.to_string()).collect(),
});
continue;
}
let scram = ScramSha256::new(
self.options.password.as_bytes(),
ChannelBinding::unrequested(),
);
let mut buf = BytesMut::new();
sasl_initial_response(SCRAM_SHA_256, scram.message(), &mut buf)?;
stream.write_all(&buf).await?;
scram_state = Some(scram);
}
Ok(Some(backend::Message::AuthenticationSaslContinue(body))) => {
if let Some(ref mut scram) = scram_state {
scram
.update(body.data())
.map_err(PgToPlError::SaslAuthenticationFailed)?;
let mut buf = BytesMut::new();
sasl_response(scram.message(), &mut buf)?;
stream.write_all(&buf).await?;
} else {
self.mark_unhealthy();
error_to_return = Some(PgToPlError::SaslStateError {
message: "AuthenticationSaslContinue",
});
}
}
Ok(Some(backend::Message::AuthenticationSaslFinal(body))) => {
if let Some(ref mut scram) = scram_state {
scram
.finish(body.data())
.map_err(PgToPlError::SaslAuthenticationFailed)?;
debug!("SASL SCRAM-SHA-256 authentication successful");
} else {
self.mark_unhealthy();
error_to_return = Some(PgToPlError::SaslStateError {
message: "AuthenticationSaslFinal",
});
}
}
Ok(Some(backend::Message::AuthenticationMd5Password(salt))) => {
let mut buf = BytesMut::new(); frontend::password_message(
md5_hash(
self.options.user.as_str(),
self.options.password.as_str(),
&salt.salt(),
)
.as_bytes(),
&mut buf,
)?;
stream.write_all(&buf).await?;
}
Ok(Some(backend::Message::ParameterStatus(body))) => {
let name = match body.name() {
Ok(name) => name,
Err(_) => "fail to parse parameter name",
};
let value = match body.value() {
Ok(value) => value,
Err(_) => "fail to parse parameter value",
};
debug!("Parameter status: {} = {}", name, value);
}
Ok(Some(backend::Message::AuthenticationOk))
| Ok(Some(backend::Message::BackendKeyData(_))) => {
}
Ok(Some(other)) => {
warn!("[Connect] Unhandled message: {:?}", other.message_name());
}
Ok(None) => {
break;
}
Err(e) => {
error_to_return = Some(e.into());
break;
}
}
}
}
if let Some(error_to_return) = error_to_return {
return Err(error_to_return);
} else {
self.healthy.store(true, Ordering::Relaxed);
self.ready.store(true, Ordering::Relaxed);
Ok(())
}
}
/// Replaces `stream` with a fresh TCP connection and redoes the startup handshake.
///
/// The old socket is shut down once, best effort, ignoring errors because it may already
/// be dead. After the new socket connects, the health/ready flags, the prepared-statement
/// cache and the portal counter are cleared, since none of them carry over to the new
/// server session. If the TCP connect itself fails, its error is returned and that state
/// is left untouched.
///
/// # Errors
///
/// Errors from `TcpStream::connect` or from `_connect`.
pub async fn reset(&self, stream: &mut TcpStream) -> PgToPlResult<()> {
    warn!("Resetting client connection");
    let _ = stream.shutdown().await;
    *stream = TcpStream::connect(self.options.connect_url()).await?;

    // Prepared statements and portals belonged to the previous server session.
    self.healthy.store(false, Ordering::SeqCst);
    self.ready.store(false, Ordering::SeqCst);
    self.prepared_statements.lock().await.clear();
    *self.portal_count.lock().await = 0;

    self._connect(stream).await
}
/// Executes `query` with `params` and returns the result set as a `DataFrame`.
///
/// Fails fast with `ConnectionBroken` when the client is unhealthy. Otherwise it locks the
/// stream and reconnects through `reset` if the previous exchange did not finish cleanly
/// (`ready == false`), then delegates to `_query`. `ready` stays cleared for the duration
/// of the call and is restored only if the client is still healthy afterwards. Any
/// failure marks the client unhealthy and is logged with the query text and parameters.
pub async fn query<P>(&self, query: &str, params: P) -> PgToPlResult<DataFrame>
where
    P: IntoIterator<Item = Option<BinaryParam>> + Clone + Debug,
{
    if self.has_broken() {
        return Err(PgToPlError::ConnectionBroken);
    }

    let mut guard = self.stream.lock().await;
    let previous_exchange_unfinished = !self.ready.load(Ordering::Relaxed);
    if previous_exchange_unfinished {
        self.reset(&mut guard).await?;
    }

    self.ready.store(false, Ordering::Relaxed);
    let outcome = self._query(query, params.clone(), &mut guard).await;
    if let Err(error) = outcome.as_ref() {
        self.mark_unhealthy();
        error!(
            "Failed to execute query '{}' with params: {:?}: {}",
            query, params, error
        );
    }

    let still_healthy = !self.has_broken();
    if still_healthy {
        self.ready.store(true, Ordering::Relaxed);
    }
    outcome
}
/// Executes `query` over the extended query protocol and collects the rows.
///
/// The first time a statement's `statement_name` is seen on this connection, the statement
/// is prepared (Parse + Describe) and cached in `prepared_statements`. Later calls reuse
/// the cached column layout. Each call:
///
/// * binds a new named portal `portal_<n>`, requesting binary format for every parameter
///   and every result column;
/// * executes it with no row limit and sends Sync;
/// * decodes `DataRow` messages until `ReadyForQuery`.
///
/// # Errors
///
/// * `ConnectionBroken` if the client is already unhealthy.
/// * `ParamTypeMismatch` if the cached statement was prepared with other parameter types.
/// * The first server `ErrorResponse` or row-decoding error (`TooFewField`,
///   `TooManyField`, column conversion errors). The response is still drained up to
///   `ReadyForQuery`.
/// * `ConnectionClosed` on EOF with no earlier error.
/// * I/O, bind-encoding, protocol-parse and DataFrame construction errors.
///
/// Server, decoding, EOF and parse errors mark the client unhealthy here. `query` also
/// marks it unhealthy on any error.
pub async fn _query<P>(
    &self,
    query: &str,
    params: P,
    stream: &mut TcpStream,
) -> PgToPlResult<DataFrame>
where
    P: IntoIterator<Item = Option<BinaryParam>> + Clone,
{
    if self.has_broken() {
        self.mark_unhealthy();
        return Err(PgToPlError::ConnectionBroken);
    }
    let portal_count = {
        let mut count = self.portal_count.lock().await;
        // Wrap instead of overflowing (a panic in debug builds) on very long-lived
        // connections.
        *count = count.wrapping_add(1);
        *count
    };
    let portal_name = format!("portal_{}", portal_count);
    let (param_types, param_values) = format_params(params);
    let name = statement_name(query);

    // Held until the rows are decoded so the cache entry and the server-side statement
    // are updated together.
    let mut prepared_statements = self.prepared_statements.lock().await;
    let (prepare, mut columns) = match prepared_statements.get(&name) {
        Some(info) => {
            if info.param_types != param_types {
                return Err(PgToPlError::ParamTypeMismatch);
            }
            // NOTE(review): `&& false` keeps this chaos-testing path disabled.
            if self.monkey_chaos_already_prepare && false {
                debug!("Monkey chaos already prepare");
                (true, Vec::new())
            } else {
                (false, info.columns.clone())
            }
        }
        None => (true, Vec::new()),
    };
    if prepare {
        columns = self
            .prepare_query(&name, query, &param_types, stream)
            .await?;
        prepared_statements.insert(
            name.clone(),
            PreparedStatementInfo {
                param_types,
                columns: clone_storages(&columns),
            },
        );
    }

    let mut buf = BytesMut::new();
    frontend::bind(
        &portal_name,
        &name,
        std::iter::repeat(1).take(param_values.len()),
        param_values.iter(),
        |val, buf| match val {
            Some(bytes) => {
                buf.put_slice(bytes);
                Ok(IsNull::No)
            }
            None => Ok(IsNull::Yes),
        },
        [1],
        &mut buf,
    )
    .map_err(|_| PgToPlError::BindError)?;
    frontend::execute(&portal_name, 0, &mut buf)?;
    frontend::sync(&mut buf);
    stream.write_all(&buf).await?;

    let mut error_to_return: Option<PgToPlError> = None;
    let mut nb_rows = 0;
    let mut read_buffer = BytesMut::with_capacity(8192);
    'read: loop {
        read_buffer.reserve(8192);
        // `read_buf` writes into the spare capacity through `BufMut`, so uninitialised
        // memory is never viewed as `&mut [u8]`.
        let n = stream.read_buf(&mut read_buffer).await?;
        if n == 0 {
            self.mark_unhealthy();
            // Surface a server error received before the socket closed, if any.
            return Err(error_to_return.unwrap_or(PgToPlError::ConnectionClosed));
        }
        loop {
            match backend::Message::parse(&mut read_buffer) {
                Ok(Some(backend::Message::DataRow(row))) => {
                    let data = row.buffer();
                    let mut ranges = row.ranges();
                    let expected = columns.len();
                    for (i, col) in columns.iter_mut().enumerate() {
                        let pushed = match ranges.next() {
                            Ok(Some(range)) => push_column_value(col, range.map(|r| &data[r])),
                            Ok(None) => {
                                self.mark_unhealthy();
                                if error_to_return.is_none() {
                                    error_to_return =
                                        Some(PgToPlError::TooFewField(expected, i));
                                }
                                break;
                            }
                            Err(e) => Err(e.into()),
                        };
                        if let Err(e) = pushed {
                            self.mark_unhealthy();
                            // Keep the first error, as for ErrorResponse below.
                            if error_to_return.is_none() {
                                error_to_return = Some(e);
                            }
                        }
                    }
                    if error_to_return.is_none() {
                        // Record, rather than propagate, so the response is still drained.
                        match ranges.next() {
                            Ok(None) => {}
                            Ok(Some(_)) => {
                                self.mark_unhealthy();
                                error_to_return = Some(PgToPlError::TooManyField(columns.len()));
                            }
                            Err(e) => {
                                self.mark_unhealthy();
                                error_to_return = Some(e.into());
                            }
                        }
                    }
                    nb_rows += 1;
                }
                Ok(Some(backend::Message::ReadyForQuery(_))) => break 'read,
                Ok(Some(backend::Message::ErrorResponse(error))) => {
                    self.mark_unhealthy();
                    if error_to_return.is_none() {
                        error_to_return = Some(error.into());
                    }
                }
                Ok(Some(backend::Message::CommandComplete(body))) => match body.tag() {
                    Ok(tag) => {
                        debug!("Command completed: {}", tag);
                    }
                    Err(err) => {
                        warn!("Error parsing command tag: {}", err);
                    }
                },
                Ok(Some(backend::Message::BindComplete)) => {}
                Ok(Some(backend::Message::EmptyQueryResponse)) => {
                    debug!("Empty query response");
                }
                Ok(Some(other)) => {
                    warn!("[Read] Unhandled message: {:?}", other.message_name());
                }
                // Incomplete message in the buffer: read more bytes.
                Ok(None) => break,
                Err(e) => {
                    // The byte stream is out of sync; reading on could block forever.
                    self.mark_unhealthy();
                    return Err(error_to_return.unwrap_or_else(|| e.into()));
                }
            }
        }
    }

    if let Some(error) = error_to_return {
        self.mark_unhealthy();
        return Err(error);
    }
    drop(prepared_statements);
    let series = columns
        .into_iter()
        .map(|col| column_to_series(col))
        .collect::<PgToPlResult<Vec<_>>>()?;
    Ok(DataFrame::new(nb_rows, series)?)
}
/// Returns `true` when the client is flagged unhealthy. The flag starts out `false`, so
/// this is also `true` until a handshake has succeeded.
pub fn has_broken(&self) -> bool {
    let is_healthy = self.healthy.load(Ordering::Relaxed);
    !is_healthy
}
/// Flags the client as broken, so subsequent `query`/`ping` calls fail fast with
/// `ConnectionBroken`.
fn mark_unhealthy(&self) {
    let health_flag = &self.healthy;
    health_flag.store(false, Ordering::Relaxed);
}
pub async fn ping(&self) -> PgToPlResult<()> {
if self.has_broken() {
return Err(PgToPlError::ConnectionBroken);
}
let mut stream = self.stream.lock().await;
let mut buf = BytesMut::new();
frontend::query("/* ping */ SELECT 1;", &mut buf)?;
stream.write_all(&buf).await?;
let mut read_buffer = BytesMut::with_capacity(4096);
let mut error_to_return: Option<PgToPlError> = None;
let mut done = false;
while !done {
read_buffer.reserve(4096);
let dst = read_buffer.chunk_mut();
let buf: &mut [u8] =
unsafe { std::slice::from_raw_parts_mut(dst.as_mut_ptr(), dst.len()) };
let n = stream.read(buf).await?;
unsafe {
read_buffer.advance_mut(n);
}
if n == 0 {
self.mark_unhealthy();
return Err(PgToPlError::ConnectionClosed);
}
loop {
let message = backend::Message::parse(&mut read_buffer);
match message {
Ok(Some(backend::Message::ReadyForQuery(_))) => {
done = true;
break;
}
Ok(Some(backend::Message::ErrorResponse(error))) => {
self.mark_unhealthy();
if error_to_return.is_none() {
error_to_return = Some(error.into());
}
}
Ok(Some(backend::Message::CommandComplete(body))) => match body.tag() {
Ok(tag) => {
debug!("Command completed: {}", tag);
}
Err(err) => {
warn!("Error parsing command tag: {}", err);
}
},
Ok(Some(backend::Message::RowDescription(_)))
| Ok(Some(backend::Message::DataRow(_))) => {
}
Ok(Some(other)) => {
warn!("[Ping] Unhandled message: {:?}", other.message_name());
}
Ok(None) => {
break;
}
Err(e) => {
error_to_return = Some(e.into());
break;
}
}
}
}
if let Some(err) = error_to_return {
self.mark_unhealthy();
Err(err)
} else {
Ok(())
}
}
async fn prepare_query(
&self,
name: &str,
query: &str,
param_types: &Vec<u32>,
stream: &mut TcpStream,
) -> PgToPlResult<Vec<ColumnStorage>> {
let res = self._prepare_query(name, query, param_types, stream).await;
if let Ok(columns) = res {
Ok(columns)
} else {
self.close_statement(name, stream).await?;
self._prepare_query(name, query, param_types, stream).await
}
}
/// Sends Parse + Describe(statement) + Sync for `name`/`query` and returns the result
/// column storages built from the server's RowDescription.
///
/// If no RowDescription arrives (the server answers NoData for statements that return no
/// rows), the returned vector is empty. A ParameterDescription that disagrees with
/// `param_types` is only logged.
///
/// # Errors
///
/// * The first server `ErrorResponse`, returned after draining to `ReadyForQuery`.
/// * `ConnectionClosed` on EOF with no error pending.
/// * I/O and parse errors.
///
/// Server, EOF and parse errors mark the client unhealthy.
async fn _prepare_query(
    &self,
    name: &str,
    query: &str,
    param_types: &[u32],
    stream: &mut TcpStream,
) -> PgToPlResult<Vec<ColumnStorage>> {
    let mut buf = BytesMut::new();
    frontend::parse(name, query, param_types.iter().copied(), &mut buf)?;
    frontend::describe(b'S', name, &mut buf)?;
    frontend::sync(&mut buf);
    stream.write_all(&buf).await?;

    let mut read_buffer = BytesMut::with_capacity(8192);
    let mut error_to_return: Option<PgToPlError> = None;
    let mut columns = Vec::new();
    'read: loop {
        read_buffer.reserve(8192);
        // `read_buf` writes into the spare capacity through `BufMut`, so uninitialised
        // memory is never viewed as `&mut [u8]`.
        let n = stream.read_buf(&mut read_buffer).await?;
        if n == 0 {
            self.mark_unhealthy();
            return Err(error_to_return.unwrap_or(PgToPlError::ConnectionClosed));
        }
        loop {
            match backend::Message::parse(&mut read_buffer) {
                Ok(Some(backend::Message::RowDescription(desc))) => {
                    columns.clear();
                    for field in desc.fields().iterator() {
                        let field = field?;
                        columns.push(column_from_field(&field));
                    }
                }
                // Sent instead of RowDescription when the statement returns no rows.
                Ok(Some(backend::Message::NoData)) => {}
                Ok(Some(backend::Message::ReadyForQuery(_))) => break 'read,
                Ok(Some(backend::Message::ErrorResponse(error))) => {
                    if error_to_return.is_none() {
                        error_to_return = Some(error.into());
                    }
                }
                Ok(Some(backend::Message::ParameterDescription(body))) => {
                    let mut described = 0;
                    for (index, parameter) in body.parameters().iterator().enumerate() {
                        described = index + 1;
                        match (parameter, param_types.get(index)) {
                            (Ok(server_type), Some(&provided)) => {
                                // "Provided" is what we sent in Parse; "expected" is what
                                // the server describes, matching the count message below.
                                if server_type != provided {
                                    warn!(
                                        "Parameter type mismatch for stmt '{}': Provided {}, expected {}",
                                        name, provided, server_type
                                    );
                                }
                            }
                            (Ok(server_type), None) => {
                                warn!(
                                    "Unexpected parameter type for stmt '{}'. Bad number of parameters will occur. Expected parameter {} at index {}",
                                    name, server_type, index
                                );
                            }
                            (Err(_), _) => {
                                warn!("Failed to parse parameter description for stmt '{}'", name);
                            }
                        }
                    }
                    if described != param_types.len() {
                        warn!(
                            "Parameter description mismatch for stmt '{}': Provided {}, expected {}",
                            name,
                            param_types.len(),
                            described,
                        );
                    }
                }
                Ok(Some(backend::Message::ParseComplete)) => {
                    debug!("Statement '{}' parsed successfully", name);
                }
                Ok(Some(other)) => {
                    warn!("[Prepare] Unhandled message: {:?}", other.message_name());
                }
                // Incomplete message in the buffer: read more bytes.
                Ok(None) => break,
                Err(e) => {
                    // The byte stream is out of sync; reading on could block forever.
                    self.mark_unhealthy();
                    return Err(error_to_return.unwrap_or_else(|| e.into()));
                }
            }
        }
    }

    if let Some(error) = error_to_return {
        self.mark_unhealthy();
        return Err(error);
    }
    Ok(columns)
}
/// Closes the prepared statement `name` on the server (Close + Sync) and waits for
/// `ReadyForQuery`.
///
/// # Errors
///
/// * The first server `ErrorResponse`, returned after draining to `ReadyForQuery`.
/// * `ConnectionClosed` on EOF with no error pending.
/// * I/O and parse errors.
///
/// Server, EOF and parse errors mark the client unhealthy.
pub async fn close_statement(&self, name: &str, stream: &mut TcpStream) -> PgToPlResult<()> {
    let mut buf = BytesMut::new();
    frontend::close(b'S', name, &mut buf)?;
    frontend::sync(&mut buf);
    stream.write_all(&buf).await?;

    let mut read_buffer = BytesMut::with_capacity(4096);
    let mut error_to_return: Option<PgToPlError> = None;
    'read: loop {
        read_buffer.reserve(4096);
        // `read_buf` writes into the spare capacity through `BufMut`, so uninitialised
        // memory is never viewed as `&mut [u8]`.
        let n = stream.read_buf(&mut read_buffer).await?;
        if n == 0 {
            // A closed socket leaves the client unusable.
            self.mark_unhealthy();
            return Err(error_to_return.unwrap_or(PgToPlError::ConnectionClosed));
        }
        loop {
            match backend::Message::parse(&mut read_buffer) {
                Ok(Some(backend::Message::CloseComplete)) => {}
                Ok(Some(backend::Message::ReadyForQuery(_))) => break 'read,
                Ok(Some(backend::Message::ErrorResponse(error))) => {
                    if error_to_return.is_none() {
                        error_to_return = Some(error.into());
                    }
                }
                Ok(Some(other)) => {
                    warn!("[CloseStmt] Unhandled message: {:?}", other.message_name());
                }
                // Incomplete message in the buffer: read more bytes.
                Ok(None) => break,
                Err(e) => {
                    // The byte stream is out of sync; reading on could block forever.
                    self.mark_unhealthy();
                    return Err(error_to_return.unwrap_or_else(|| e.into()));
                }
            }
        }
    }

    if let Some(error) = error_to_return {
        self.mark_unhealthy();
        return Err(error);
    }
    Ok(())
}
}