use std::sync::Arc;
use async_trait::async_trait;
use pgwire::api::results::FieldInfo;
use pgwire::api::stmt::QueryParser;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::PgWireResult;
use crate::control::state::SharedState;
use super::statement::ParsedStatement;
use parser_schema::{
count_placeholders, fields_from_projection, infer_result_fields, is_dsl_statement,
parse_select_projection, result_fields_for_returning, substitute_placeholders_with_null,
};
#[path = "parser_schema.rs"]
mod parser_schema;
/// pgwire extended-query parser for NodeDB.
///
/// Keeps a handle on the shared server state so that, when a statement is
/// prepared, parameter types and result columns can be inferred against the
/// catalog of the connecting tenant.
pub struct NodeDbQueryParser {
    state: Arc<SharedState>,
}

impl NodeDbQueryParser {
    /// Builds a parser backed by `state`.
    pub fn new(state: Arc<SharedState>) -> Self {
        Self { state }
    }

    /// Best-effort inference of parameter types and result columns for `sql`.
    ///
    /// The returned parameter vector has `max(placeholder count, client_types.len())`
    /// slots. Each slot holds the type the client declared for it, or `None` when
    /// the client left it unspecified.
    ///
    /// Result columns are chosen in this order:
    /// 1. a trailing RETURNING clause, when `result_fields_for_returning` yields fields;
    /// 2. the parsed SELECT projection, when `parse_select_projection` recognises one;
    /// 3. inference from the first produced plan;
    /// 4. otherwise, no columns.
    ///
    /// If stripping RETURNING or planning fails, the parameter types are still
    /// returned, paired with an empty column list.
    fn try_infer_types(
        &self,
        sql: &str,
        client_types: &[Option<Type>],
        tenant_id: u64,
    ) -> (Vec<Option<Type>>, Vec<FieldInfo>) {
        // Catalog view scoped to this tenant and the default database.
        let catalog = crate::control::planner::catalog_adapter::OriginCatalog::new(
            Arc::clone(&self.state.credentials),
            tenant_id,
            crate::types::DatabaseId::DEFAULT,
            Some(Arc::clone(&self.state.retention_policy_registry)),
        );

        // One slot per placeholder, widened if the client declared more types.
        let slot_count = count_placeholders(sql).max(client_types.len());
        let param_types: Vec<Option<Type>> = (0..slot_count)
            .map(|idx| client_types.get(idx).cloned().flatten())
            .collect();

        let Ok((sql_stripped, returning_spec)) =
            crate::control::server::pgwire::handler::returning::strip_returning(sql)
        else {
            return (param_types, Vec::new());
        };

        // Placeholders are swapped out so the planner sees self-contained SQL.
        let inference_sql = substitute_placeholders_with_null(&sql_stripped);
        let Ok(plans) = nodedb_sql::plan_sql(&inference_sql, &catalog) else {
            return (param_types, Vec::new());
        };
        let first_plan = plans.first();

        // A RETURNING clause, when resolvable, dictates the result shape.
        if let Some(fields) = returning_spec
            .as_ref()
            .and_then(|spec| result_fields_for_returning(spec, first_plan, &catalog))
        {
            return (param_types, fields);
        }

        let result_fields = match (parse_select_projection(&inference_sql), first_plan) {
            (Some(projection), plan) => fields_from_projection(&projection, plan, &catalog),
            (None, Some(plan)) => infer_result_fields(plan, &catalog),
            (None, None) => Vec::new(),
        };

        (param_types, result_fields)
    }
}
#[async_trait]
impl QueryParser for NodeDbQueryParser {
    type Statement = ParsedStatement;

    /// Turns `sql` into a prepared statement carrying parameter and result metadata.
    ///
    /// The statement takes one of three paths:
    /// - Statements recognised by `crate::control::backup::detect` pass through
    ///   with no parameter or result metadata.
    /// - Queries for which `extract_pg_catalog_table` names a table get that
    ///   table's fixed schema.
    /// - Anything else goes through `try_infer_types` against the caller's tenant.
    ///   `is_dsl` is set only when that inference produced no result columns and
    ///   `is_dsl_statement` matches.
    async fn parse_sql<C>(
        &self,
        client: &C,
        sql: &str,
        types: &[Option<Type>],
    ) -> PgWireResult<Self::Statement>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        if crate::control::backup::detect(sql).is_some() {
            return Ok(ParsedStatement {
                sql: sql.to_owned(),
                param_types: Vec::new(),
                result_fields: Vec::new(),
                is_dsl: false,
                pg_catalog_table: None,
            });
        }

        // Catalog table detection works on the upper-cased text.
        let upper_sql = sql.to_uppercase();
        let catalog_table =
            crate::control::server::pgwire::pg_catalog::extract_pg_catalog_table(&upper_sql);
        if let Some(table) = catalog_table {
            let fields = crate::control::server::pgwire::pg_catalog::pg_catalog_schema(table)
                .unwrap_or_default();
            let slot_count = count_placeholders(sql).max(types.len());
            let declared = (0..slot_count)
                .map(|idx| types.get(idx).cloned().flatten())
                .collect::<Vec<Option<Type>>>();
            return Ok(ParsedStatement {
                sql: sql.to_owned(),
                param_types: declared,
                result_fields: fields,
                is_dsl: false,
                pg_catalog_table: Some(table),
            });
        }

        // Resolve the tenant from the connection's "user" metadata by trying
        // Trust first, then ScramSha256.
        // NOTE(review): an unresolved or missing user falls back to tenant 1. Confirm
        // that this is the intended default, and that users authenticated by other
        // methods cannot end up inferring schemas from another tenant's catalog.
        let credentials = &self.state.credentials;
        let tenant_id = match client.metadata().get("user") {
            Some(user) => credentials
                .to_identity(user, crate::control::security::identity::AuthMethod::Trust)
                .or_else(|| {
                    credentials.to_identity(
                        user,
                        crate::control::security::identity::AuthMethod::ScramSha256,
                    )
                })
                .map_or(1, |identity| identity.tenant_id.as_u64()),
            None => 1,
        };

        let (param_types, result_fields) = self.try_infer_types(sql, types, tenant_id);
        // The DSL check only runs when inference found no result columns.
        let is_dsl = if result_fields.is_empty() {
            is_dsl_statement(sql)
        } else {
            false
        };

        Ok(ParsedStatement {
            sql: sql.to_owned(),
            param_types,
            result_fields,
            is_dsl,
            pg_catalog_table: None,
        })
    }

    /// Reports parameter types, using `Type::UNKNOWN` for undetermined slots.
    fn get_parameter_types(&self, stmt: &Self::Statement) -> PgWireResult<Vec<Type>> {
        let described: Vec<Type> = stmt
            .param_types
            .iter()
            .map(|slot| match slot {
                Some(ty) => ty.clone(),
                None => Type::UNKNOWN,
            })
            .collect();
        Ok(described)
    }

    /// Reports the result columns computed when the statement was parsed.
    fn get_result_schema(
        &self,
        stmt: &Self::Statement,
        _column_format: Option<&pgwire::api::portal::Format>,
    ) -> PgWireResult<Vec<FieldInfo>> {
        Ok(stmt.result_fields.to_vec())
    }
}