use std::sync::Arc;
use futures::StreamExt;
use pgwire::api::Type;
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use super::super::types::text_field;
/// One entry of a `SELECT` list, as extracted by [`parse_select_projection`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum ProjectionItem {
    /// A `*` or `qualifier.*` wildcard item.
    Star,
    /// An explicit column reference or expression, optionally aliased.
    Named {
        /// Key used to look the value up in each flattened JSON row
        /// (a dotted path such as `t.col` for compound identifiers).
        lookup_key: String,
        /// Column name reported to the client (the alias when one was given).
        display_name: String,
    },
}
/// Parses `sql` with the PostgreSQL dialect and returns its select list.
///
/// Only the first statement is inspected. Returns `None` when the SQL does not
/// parse, contains no statement, or its first statement is not a plain
/// `SELECT` (for example a non-query statement, a `UNION`, or `VALUES`).
///
/// Wildcards (`*`, `t.*`) become [`ProjectionItem::Star`]. Unaliased
/// expressions take both names from `expr_column_names`. Aliased expressions
/// use the expression for the lookup key and the alias text for the display name.
pub(super) fn parse_select_projection(sql: &str) -> Option<Vec<ProjectionItem>> {
    use sqlparser::ast::{SelectItem, SetExpr, Statement};
    use sqlparser::dialect::PostgreSqlDialect;
    use sqlparser::parser::Parser;

    let first_statement = Parser::parse_sql(&PostgreSqlDialect {}, sql)
        .ok()?
        .into_iter()
        .next()?;
    let query = match first_statement {
        Statement::Query(query) => query,
        _ => return None,
    };
    let select = match *query.body {
        SetExpr::Select(select) => select,
        _ => return None,
    };

    let items = select
        .projection
        .iter()
        .map(|select_item| match select_item {
            SelectItem::Wildcard(_) | SelectItem::QualifiedWildcard(..) => ProjectionItem::Star,
            SelectItem::UnnamedExpr(expr) => {
                let (lookup_key, display_name) = expr_column_names(expr);
                ProjectionItem::Named {
                    lookup_key,
                    display_name,
                }
            }
            SelectItem::ExprWithAlias { expr, alias } => ProjectionItem::Named {
                lookup_key: expr_column_names(expr).0,
                display_name: alias.value.clone(),
            },
        })
        .collect();
    Some(items)
}
/// Derives `(lookup_key, display_name)` for an unaliased projection expression.
///
/// - A plain identifier uses its text for both names.
/// - A compound identifier `a.b.c` uses the dotted path as the lookup key and
///   the last segment as the display name.
/// - Any other expression uses its lowercased SQL rendering for both names.
fn expr_column_names(expr: &sqlparser::ast::Expr) -> (String, String) {
    use sqlparser::ast::Expr;
    match expr {
        Expr::Identifier(ident) => (ident.value.clone(), ident.value.clone()),
        Expr::CompoundIdentifier(parts) => {
            let segments: Vec<&str> = parts.iter().map(|part| part.value.as_str()).collect();
            let lookup_key = segments.join(".");
            let display_name = match segments.last() {
                Some(last) => (*last).to_owned(),
                None => lookup_key.clone(),
            };
            (lookup_key, display_name)
        }
        other => {
            let rendered = other.to_string().to_lowercase();
            (rendered.clone(), rendered)
        }
    }
}
/// Returns `true` when at least one item is an explicit column or expression.
///
/// A list made up only of wildcards, or an empty list, yields `false`.
pub(super) fn needs_projection(items: &[ProjectionItem]) -> bool {
    items.iter().any(|item| match item {
        ProjectionItem::Named { .. } => true,
        ProjectionItem::Star => false,
    })
}
/// Builds the row description for the explicit items of a select list.
///
/// Each [`ProjectionItem::Named`] yields one text-format `TEXT` column named
/// after its display name, in select-list order. Wildcards are skipped.
pub(super) fn fields_for_projection(items: &[ProjectionItem]) -> Vec<FieldInfo> {
    let mut fields = Vec::new();
    for item in items {
        match item {
            ProjectionItem::Named { display_name, .. } => fields.push(FieldInfo::new(
                display_name.clone(),
                None,
                None,
                Type::TEXT,
                FieldFormat::Text,
            )),
            ProjectionItem::Star => {}
        }
    }
    fields
}
/// Returns the lookup keys of the explicit items of a select list, in order.
///
/// Wildcards are skipped. The result is therefore index-aligned with
/// `fields_for_projection` for the same `items`.
pub(super) fn lookup_keys_for_projection(items: &[ProjectionItem]) -> Vec<String> {
    let mut keys = Vec::new();
    for item in items {
        match item {
            ProjectionItem::Named { lookup_key, .. } => keys.push(lookup_key.clone()),
            ProjectionItem::Star => {}
        }
    }
    keys
}
/// Re-encodes a JSON-envelope query response as one text column per projected item.
///
/// Non-`Query` responses are returned unchanged. Otherwise every row is
/// buffered and flattened via `collect_flat_rows`. Then, for each entry of
/// `lookup_keys`, the value is resolved by trying, in order:
/// 1. the full lookup key (e.g. `t.name`),
/// 2. the bare key after the last `.` (e.g. `name`),
/// 3. the name of the result field at the same index.
///
/// A missing or JSON `null` value is encoded as SQL NULL. A JSON string is
/// encoded as its contents. Any other value is encoded as its JSON text.
///
/// `result_fields` and `lookup_keys` are expected to be index-aligned, as
/// produced by `fields_for_projection` / `lookup_keys_for_projection`.
///
/// # Errors
/// Propagates errors from the upstream row stream, malformed-envelope errors
/// from `collect_flat_rows`, and any field-encoding error.
pub(super) async fn reproject_response(
    response: Response,
    result_fields: &[FieldInfo],
    lookup_keys: &[String],
) -> PgWireResult<Response> {
    let qr = match response {
        Response::Query(qr) => qr,
        other => return Ok(other),
    };
    let schema = Arc::new(result_fields.to_vec());
    let flat_rows = collect_flat_rows(qr).await?;
    let mut pgwire_rows = Vec::with_capacity(flat_rows.len());
    for obj in &flat_rows {
        let mut encoder = DataRowEncoder::new(schema.clone());
        for (col_idx, lookup_key) in lookup_keys.iter().enumerate() {
            let lookup_key = lookup_key.as_str();
            let bare = lookup_key
                .rsplit_once('.')
                .map_or(lookup_key, |(_, tail)| tail);
            let display_name: Option<&str> = result_fields.get(col_idx).map(|f| f.name());
            // Each fallback is skipped when it would repeat a key already tried.
            let value = obj
                .get(lookup_key)
                .or_else(|| if bare != lookup_key { obj.get(bare) } else { None })
                .or_else(|| {
                    display_name
                        .filter(|&n| n != lookup_key && n != bare)
                        .and_then(|n| obj.get(n))
                });
            // Propagate encode failures: a silently skipped field would leave the
            // DataRow with fewer columns than the RowDescription announces.
            match value {
                None | Some(serde_json::Value::Null) => {
                    encoder.encode_field(&Option::<String>::None)?
                }
                Some(serde_json::Value::String(s)) => encoder.encode_field(s)?,
                Some(other) => encoder.encode_field(&other.to_string())?,
            }
        }
        pgwire_rows.push(Ok(encoder.take_row()));
    }
    Ok(Response::Query(QueryResponse::new(
        schema,
        futures::stream::iter(pgwire_rows),
    )))
}
/// Drains `qr`'s row stream and flattens the JSON carried in each row.
///
/// The first field of every data row is decoded as UTF-8 text and parsed as
/// JSON. Rows whose first field is NULL, truncated, or not valid UTF-8 are
/// skipped. Parsed values are flattened into row objects by `push_flat_rows`.
///
/// # Errors
/// Propagates stream errors. Returns an `XX000` user error when a payload is
/// not valid JSON.
pub(super) async fn collect_flat_rows(
    mut qr: QueryResponse,
) -> PgWireResult<Vec<serde_json::Map<String, serde_json::Value>>> {
    let mut flattened = Vec::new();
    while let Some(next_row) = qr.data_rows.next().await {
        let data_row = next_row?;
        if let Some(payload) = decode_first_field_text(&data_row.data) {
            let parsed = sonic_rs::from_str::<serde_json::Value>(payload).map_err(|err| {
                PgWireError::UserError(Box::new(ErrorInfo::new(
                    "ERROR".to_owned(),
                    "XX000".to_owned(),
                    format!("malformed Data-Plane response envelope: {err}"),
                )))
            })?;
            push_flat_rows(parsed, &mut flattened);
        }
    }
    Ok(flattened)
}
/// Appends the row objects contained in `value` to `out`.
///
/// Arrays are walked recursively. Objects are appended as rows, except that a
/// scan wrapper (see `is_scan_wrapper`) is replaced by its `data` object; the
/// wrapper's `id` is dropped. Scalars and `null` contribute nothing.
pub(super) fn push_flat_rows(
    value: serde_json::Value,
    out: &mut Vec<serde_json::Map<String, serde_json::Value>>,
) {
    match value {
        serde_json::Value::Array(items) => {
            items.into_iter().for_each(|item| push_flat_rows(item, out));
        }
        serde_json::Value::Object(mut map) => {
            if !is_scan_wrapper(&map) {
                out.push(map);
                return;
            }
            match map.remove("data") {
                Some(serde_json::Value::Object(inner)) => out.push(inner),
                _ => out.push(map),
            }
        }
        _ => {}
    }
}
/// Returns `true` for objects shaped exactly like `{"id": <string>, "data": <object>}`.
pub(super) fn is_scan_wrapper(map: &serde_json::Map<String, serde_json::Value>) -> bool {
    if map.len() != 2 {
        return false;
    }
    let has_string_id = map.get("id").is_some_and(serde_json::Value::is_string);
    let has_object_data = map.get("data").is_some_and(serde_json::Value::is_object);
    has_string_id && has_object_data
}
/// Re-encodes a JSON-envelope query response as text columns for `SELECT *`.
///
/// Non-`Query` responses are returned unchanged. With no rows, the result has a
/// single `result` text column and no data. Otherwise the columns are:
/// - `id` first, when the first row has it;
/// - then the first row's remaining keys, in map iteration order;
/// - then any key first seen in a later row, in order of first appearance.
///
/// Missing or JSON `null` values become SQL NULL. JSON strings are sent as
/// their contents. Other values are sent as JSON text.
///
/// # Errors
/// Propagates errors from `collect_flat_rows` and any field-encoding error.
pub(super) async fn reproject_star_response(response: Response) -> PgWireResult<Response> {
    use std::collections::HashSet;

    let qr = match response {
        Response::Query(qr) => qr,
        other => return Ok(other),
    };
    let flat_rows = collect_flat_rows(qr).await?;
    let Some(first) = flat_rows.first() else {
        let schema = Arc::new(vec![text_field("result")]);
        return Ok(Response::Query(QueryResponse::new(
            schema,
            futures::stream::iter(Vec::<PgWireResult<_>>::new()),
        )));
    };

    // A set of seen keys keeps column discovery linear in the total key count,
    // instead of scanning `cols` for every key of every row.
    let mut cols: Vec<String> = Vec::new();
    let mut seen: HashSet<&str> = HashSet::new();
    if first.contains_key("id") {
        cols.push("id".to_string());
        seen.insert("id");
    }
    for key in flat_rows.iter().flat_map(|row| row.keys()) {
        if seen.insert(key.as_str()) {
            cols.push(key.clone());
        }
    }

    let schema: Arc<Vec<_>> = Arc::new(cols.iter().map(|c| text_field(c)).collect());
    let mut pgwire_rows: Vec<PgWireResult<_>> = Vec::with_capacity(flat_rows.len());
    for obj in &flat_rows {
        let mut encoder = DataRowEncoder::new(schema.clone());
        for col in &cols {
            // Propagate encode failures rather than emitting a short DataRow.
            match obj.get(col.as_str()) {
                None | Some(serde_json::Value::Null) => {
                    encoder.encode_field(&Option::<String>::None)?
                }
                Some(serde_json::Value::String(s)) => encoder.encode_field(s)?,
                Some(other) => encoder.encode_field(&other.to_string())?,
            }
        }
        pgwire_rows.push(Ok(encoder.take_row()));
    }
    Ok(Response::Query(QueryResponse::new(
        schema,
        futures::stream::iter(pgwire_rows),
    )))
}
/// Decodes the first column of a wire-format data-row body as UTF-8 text.
///
/// `data` is read as a 4-byte big-endian signed length followed by that many
/// payload bytes. Returns `None` when the length is negative (a NULL value),
/// the buffer is shorter than the header or payload, or the payload is not
/// valid UTF-8.
pub(super) fn decode_first_field_text(data: &bytes::BytesMut) -> Option<&str> {
    let (header, rest) = data.split_first_chunk::<4>()?;
    let payload_len = usize::try_from(i32::from_be_bytes(*header)).ok()?;
    let payload = rest.get(..payload_len)?;
    std::str::from_utf8(payload).ok()
}