use std::sync::Arc;
use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::copy::NoopCopyHandler;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response};
use pgwire::api::stmt::QueryParser;
use pgwire::api::stmt::StoredStatement;
use pgwire::api::{ClientInfo, PgWireHandlerFactory, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use tokio::sync::Mutex;
use crate::datatypes::{self, into_pg_type};
pub struct HandlerFactory(pub Arc<DfSessionService>);
impl PgWireHandlerFactory for HandlerFactory {
type StartupHandler = NoopStartupHandler;
type SimpleQueryHandler = DfSessionService;
type ExtendedQueryHandler = DfSessionService;
type CopyHandler = NoopCopyHandler;
fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
self.0.clone()
}
fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
self.0.clone()
}
fn startup_handler(&self) -> Arc<Self::StartupHandler> {
Arc::new(NoopStartupHandler)
}
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
Arc::new(NoopCopyHandler)
}
}
pub struct DfSessionService {
session_context: Arc<Mutex<SessionContext>>,
parser: Arc<Parser>,
}
impl DfSessionService {
pub fn new(session_context: SessionContext) -> DfSessionService {
let session_context = Arc::new(Mutex::new(session_context));
let parser = Arc::new(Parser {
session_context: session_context.clone(),
});
DfSessionService {
session_context,
parser,
}
}
}
#[async_trait]
impl SimpleQueryHandler for DfSessionService {
async fn do_query<'a, C>(
&self,
_client: &mut C,
query: &'a str,
) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Unpin + Send + Sync,
{
if query.to_uppercase().starts_with("SELECT") {
let ctx = self.session_context.lock().await;
let df = ctx
.sql(query)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?;
Ok(vec![Response::Query(resp)])
} else {
Ok(vec![Response::Error(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
"Only select statements is supported by this tool.".to_owned(),
)))])
}
}
}
pub struct Parser {
session_context: Arc<Mutex<SessionContext>>,
}
#[async_trait]
impl QueryParser for Parser {
type Statement = LogicalPlan;
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
let context = self.session_context.lock().await;
let state = context.state();
let logical_plan = state
.create_logical_plan(sql)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let optimised = state
.optimize(&logical_plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok(optimised)
}
}
#[async_trait]
impl ExtendedQueryHandler for DfSessionService {
type Statement = LogicalPlan;
type QueryParser = Parser;
fn query_parser(&self) -> Arc<Self::QueryParser> {
self.parser.clone()
}
async fn do_describe_statement<C>(
&self,
_client: &mut C,
target: &StoredStatement<Self::Statement>,
) -> PgWireResult<DescribeStatementResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
let plan = &target.statement;
let schema = plan.schema();
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?;
let params = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
dbg!(¶ms);
let mut param_types = Vec::with_capacity(params.len());
for param_type in params.into_values() {
if let Some(datatype) = param_type {
let pgtype = into_pg_type(&datatype)?;
param_types.push(pgtype);
} else {
param_types.push(Type::UNKNOWN);
}
}
Ok(DescribeStatementResponse::new(param_types, fields))
}
async fn do_describe_portal<C>(
&self,
_client: &mut C,
target: &Portal<Self::Statement>,
) -> PgWireResult<DescribePortalResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
let plan = &target.statement.statement;
let format = &target.result_column_format;
let schema = plan.schema();
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?;
Ok(DescribePortalResponse::new(fields))
}
async fn do_query<'a, C>(
&self,
_client: &mut C,
portal: &'a Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response<'a>>
where
C: ClientInfo + Unpin + Send + Sync,
{
let plan = &portal.statement.statement;
let param_values = datatypes::deserialize_parameters(
portal,
&plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
.values()
.map(|v| v.as_ref())
.collect::<Vec<Option<&DataType>>>(),
)?;
let plan = plan
.clone()
.replace_params_with_values(¶m_values)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let dataframe = self
.session_context
.lock()
.await
.execute_logical_plan(plan)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let resp = datatypes::encode_dataframe(dataframe, &portal.result_column_format).await?;
Ok(Response::Query(resp))
}
}