use crate::algebra::{AggregateType, CompareOp, Expr, JoinCondition, Operand, Predicate, SortOrder};
use crate::backends::Backend;
use crate::schema::{Column, DataType, ResultSet, Row, Schema, Value};
use crate::{RealError, Result};
/// Stateless handle for the PostgreSQL backend.
///
/// The type carries no fields; its methods translate algebra expressions and
/// predicates into PostgreSQL-flavoured SQL text.
pub struct PostgresBackend;
/// A compiled query: SQL text plus the values bound to its placeholders.
#[derive(Debug, Clone)]
pub struct PostgresQuery {
/// SQL text; literal values appear as numbered `$n` placeholders.
pub sql: String,
/// Values for the placeholders; `$n` in `sql` refers to `params[n - 1]`.
pub params: Vec<Value>,
/// Schema attached to the query by the compiler, describing its result.
pub result_schema: Schema,
}
impl PostgresBackend {
    /// Creates a new PostgreSQL backend handle.
    pub fn new() -> Self {
        Self
    }

    /// Renders `pred` as a SQL boolean expression.
    ///
    /// Every literal is appended to `params` and referenced in the returned
    /// text through a numbered placeholder (`$n`, where `n` is the length of
    /// `params` right after the push). Because numbering is derived from the
    /// current length of `params`, callers that already hold bound values can
    /// pass the same vector and the new placeholders continue after them.
    ///
    /// Column names are emitted verbatim (unquoted).
    ///
    /// # Errors
    ///
    /// No arm produces an error itself; the `Result` only propagates errors
    /// from recursive calls on nested predicates.
    fn compile_predicate(&self, pred: &Predicate, params: &mut Vec<Value>) -> Result<String> {
        match pred {
            Predicate::Compare { left, op, right } => {
                let op_sql = match op {
                    CompareOp::Eq => "=",
                    CompareOp::NotEq => "!=",
                    CompareOp::Lt => "<",
                    CompareOp::Lte => "<=",
                    CompareOp::Gt => ">",
                    CompareOp::Gte => ">=",
                };
                let right_sql = match right {
                    Operand::Column(col) => col.name.clone(),
                    Operand::Literal(val) => {
                        params.push(val.clone());
                        format!("${}", params.len())
                    }
                };
                Ok(format!("{} {} {}", left.name, op_sql, right_sql))
            }
            Predicate::And(left, right) => {
                let left_sql = self.compile_predicate(left, params)?;
                let right_sql = self.compile_predicate(right, params)?;
                Ok(format!("({} AND {})", left_sql, right_sql))
            }
            Predicate::Or(left, right) => {
                let left_sql = self.compile_predicate(left, params)?;
                let right_sql = self.compile_predicate(right, params)?;
                Ok(format!("({} OR {})", left_sql, right_sql))
            }
            Predicate::Not(inner) => {
                let inner_sql = self.compile_predicate(inner, params)?;
                Ok(format!("NOT ({})", inner_sql))
            }
            Predicate::In { column, values } => {
                // `col IN ()` is a syntax error in PostgreSQL. Membership in an
                // empty list can never hold, so emit a constant instead.
                if values.is_empty() {
                    return Ok("FALSE".to_string());
                }
                let placeholders = values
                    .iter()
                    .map(|v| {
                        params.push(v.clone());
                        format!("${}", params.len())
                    })
                    .collect::<Vec<_>>()
                    .join(", ");
                Ok(format!("{} IN ({})", column.name, placeholders))
            }
            Predicate::Like { column, pattern } => {
                params.push(Value::String(pattern.clone()));
                Ok(format!("{} LIKE ${}", column.name, params.len()))
            }
            Predicate::IsNull(column) => Ok(format!("{} IS NULL", column.name)),
            Predicate::Between { column, low, high } => {
                params.push(low.clone());
                let low_idx = params.len();
                params.push(high.clone());
                let high_idx = params.len();
                Ok(format!(
                    "{} BETWEEN ${} AND ${}",
                    column.name, low_idx, high_idx
                ))
            }
        }
    }
}
impl Default for PostgresBackend {
fn default() -> Self {
Self::new()
}
}
/// Private helpers used by the feature-enabled [`Backend`] implementation.
#[cfg(feature = "backend-postgres")]
impl PostgresBackend {
    /// Compiles `expr` into SQL text and its result schema.
    ///
    /// All bound values of the whole expression tree are appended to the one
    /// shared `params` list. Placeholder numbers are taken from the length of
    /// that list, so every `$n` in the returned SQL refers to `params[n - 1]`
    /// no matter how deeply the sub-expression that produced it is nested.
    ///
    /// # Errors
    ///
    /// Returns `RealError::Backend` for expression variants that have no
    /// translation here, and propagates errors from nested compilation.
    fn compile_into(&self, expr: &Expr, params: &mut Vec<Value>) -> Result<(String, Schema)> {
        let compiled = match expr {
            Expr::Relation { name, schema } => {
                (format!("SELECT * FROM {}", name), schema.clone())
            }
            Expr::Select { input, predicate } => {
                // Compile the input first so the predicate's placeholders
                // continue after the ones the input already claimed.
                let (inner_sql, inner_schema) = self.compile_into(input, params)?;
                let where_clause = self.compile_predicate(predicate, params)?;
                (
                    format!("SELECT * FROM ({}) AS t WHERE {}", inner_sql, where_clause),
                    inner_schema,
                )
            }
            Expr::Project { input, columns } => {
                let (inner_sql, _) = self.compile_into(input, params)?;
                (
                    format!("SELECT {} FROM ({}) AS t", columns.join(", "), inner_sql),
                    expr.infer_schema(),
                )
            }
            Expr::Join {
                left,
                right,
                condition,
            } => {
                let (left_sql, _) = self.compile_into(left, params)?;
                let (right_sql, _) = self.compile_into(right, params)?;
                let join_clause = match condition {
                    JoinCondition::Using(cols) => format!("USING ({})", cols.join(", ")),
                    JoinCondition::On(pred) => {
                        format!("ON {}", self.compile_predicate(pred, params)?)
                    }
                };
                (
                    format!(
                        "SELECT * FROM ({}) AS l JOIN ({}) AS r {}",
                        left_sql, right_sql, join_clause
                    ),
                    Schema::new("join_result"),
                )
            }
            Expr::Aggregate {
                input,
                group_by,
                aggregates,
            } => {
                let (inner_sql, _) = self.compile_into(input, params)?;
                let agg_exprs = aggregates.iter().map(|agg| {
                    let func_name = match agg.func {
                        AggregateType::Count => "COUNT",
                        AggregateType::Sum => "SUM",
                        AggregateType::Avg => "AVG",
                        AggregateType::Min => "MIN",
                        AggregateType::Max => "MAX",
                    };
                    format!("{}({}) AS {}", func_name, agg.input, agg.name)
                });
                // Joining one combined list avoids a dangling comma when
                // either the grouping columns or the aggregates are empty.
                let select_list = group_by
                    .iter()
                    .cloned()
                    .chain(agg_exprs)
                    .collect::<Vec<_>>()
                    .join(", ");
                let sql = if group_by.is_empty() {
                    format!("SELECT {} FROM ({}) AS t", select_list, inner_sql)
                } else {
                    format!(
                        "SELECT {} FROM ({}) AS t GROUP BY {}",
                        select_list,
                        inner_sql,
                        group_by.join(", ")
                    )
                };
                (sql, expr.infer_schema())
            }
            Expr::Union { left, right } => self.compile_set_op("UNION", left, right, params)?,
            Expr::Intersect { left, right } => {
                self.compile_set_op("INTERSECT", left, right, params)?
            }
            Expr::Difference { left, right } => {
                self.compile_set_op("EXCEPT", left, right, params)?
            }
            Expr::Rename { input, from, to } => {
                let (inner_sql, inner_schema) = self.compile_into(input, params)?;
                let from_name = from.to_string();
                let to_name = to.to_string();
                // The rename has to be visible in the SQL itself, otherwise an
                // enclosing expression that refers to the new name would not
                // resolve. When the input schema lists no columns there is
                // nothing to enumerate, so fall back to `*`.
                let select_list = if inner_schema.columns.is_empty() {
                    "*".to_string()
                } else {
                    inner_schema
                        .columns
                        .iter()
                        .map(|col| {
                            if col.name == from_name {
                                format!("{} AS {}", from_name, to_name)
                            } else {
                                col.name.clone()
                            }
                        })
                        .collect::<Vec<_>>()
                        .join(", ")
                };
                (
                    format!("SELECT {} FROM ({}) AS t", select_list, inner_sql),
                    expr.infer_schema(),
                )
            }
            Expr::Sort { input, columns } => {
                let (inner_sql, inner_schema) = self.compile_into(input, params)?;
                let order_by = columns
                    .iter()
                    .map(|(col, order)| {
                        let order_str = match order {
                            SortOrder::Asc => "ASC",
                            SortOrder::Desc => "DESC",
                        };
                        format!("{} {}", col, order_str)
                    })
                    .collect::<Vec<_>>()
                    .join(", ");
                (
                    format!("SELECT * FROM ({}) AS t ORDER BY {}", inner_sql, order_by),
                    inner_schema,
                )
            }
            Expr::Limit { input, count } => {
                let (inner_sql, inner_schema) = self.compile_into(input, params)?;
                (
                    format!("SELECT * FROM ({}) AS t LIMIT {}", inner_sql, count),
                    inner_schema,
                )
            }
            Expr::Offset { input, count } => {
                let (inner_sql, inner_schema) = self.compile_into(input, params)?;
                (
                    format!("SELECT * FROM ({}) AS t OFFSET {}", inner_sql, count),
                    inner_schema,
                )
            }
            _ => {
                return Err(RealError::Backend(format!(
                    "PostgreSQL backend does not yet support: {:?}",
                    expr
                )))
            }
        };
        Ok(compiled)
    }

    /// Compiles a binary set operation (`UNION`, `INTERSECT`, `EXCEPT`).
    ///
    /// Both operands share `params`, so the right operand's placeholders are
    /// numbered after the left operand's. The left operand's schema is used as
    /// the result schema.
    fn compile_set_op(
        &self,
        keyword: &str,
        left: &Expr,
        right: &Expr,
        params: &mut Vec<Value>,
    ) -> Result<(String, Schema)> {
        let (left_sql, left_schema) = self.compile_into(left, params)?;
        let (right_sql, _) = self.compile_into(right, params)?;
        Ok((
            format!("({}) {} ({})", left_sql, keyword, right_sql),
            left_schema,
        ))
    }

    /// Wraps a driver error in the crate's backend error variant.
    fn backend_error(err: postgres::Error) -> RealError {
        RealError::Backend(err.to_string())
    }

    /// Decodes the cell at `idx` of `row` by probing Rust types in turn.
    ///
    /// The driver only converts a column to a Rust type of matching width, so
    /// the narrower integer and float types are probed as well and widened.
    /// A cell that matches none of the probed types (including SQL NULL)
    /// becomes `Value::Null`.
    fn decode_cell(row: &postgres::Row, idx: usize) -> Value {
        row.try_get::<_, i64>(idx)
            .map(Value::Integer)
            .or_else(|_| {
                row.try_get::<_, i32>(idx)
                    .map(|v| Value::Integer(i64::from(v)))
            })
            .or_else(|_| {
                row.try_get::<_, i16>(idx)
                    .map(|v| Value::Integer(i64::from(v)))
            })
            .or_else(|_| row.try_get::<_, f64>(idx).map(Value::Float))
            .or_else(|_| {
                row.try_get::<_, f32>(idx)
                    .map(|v| Value::Float(f64::from(v)))
            })
            .or_else(|_| row.try_get::<_, String>(idx).map(Value::String))
            .or_else(|_| row.try_get::<_, bool>(idx).map(Value::Boolean))
            .or_else(|_| row.try_get::<_, Vec<u8>>(idx).map(Value::Bytes))
            .unwrap_or(Value::Null)
    }
}

/// PostgreSQL implementation of [`Backend`], available when the
/// `backend-postgres` feature is enabled.
#[cfg(feature = "backend-postgres")]
impl Backend for PostgresBackend {
    type Connection = postgres::Client;
    type CompiledQuery = PostgresQuery;

    /// Translates `expr` into a parameterised SQL query.
    ///
    /// # Errors
    ///
    /// Returns `RealError::Backend` if `expr` contains a variant this backend
    /// does not translate.
    fn compile(&self, expr: &Expr) -> Result<Self::CompiledQuery> {
        let mut params = Vec::new();
        let (sql, result_schema) = self.compile_into(expr, &mut params)?;
        Ok(PostgresQuery {
            sql,
            params,
            result_schema,
        })
    }

    /// Runs `query` on `conn` and converts every returned row into values.
    ///
    /// # Errors
    ///
    /// Returns `RealError::Backend` carrying the driver's message when the
    /// query fails.
    fn execute(
        &self,
        conn: &mut Self::Connection,
        query: &Self::CompiledQuery,
    ) -> Result<ResultSet> {
        let params: Vec<&(dyn postgres::types::ToSql + Sync)> = query
            .params
            .iter()
            .map(|v| v as &(dyn postgres::types::ToSql + Sync))
            .collect();
        let rows = conn
            .query(&query.sql, &params[..])
            .map_err(Self::backend_error)?;
        Ok(rows
            .iter()
            .map(|row| {
                (0..row.len())
                    .map(|idx| Self::decode_cell(row, idx))
                    .collect::<Vec<Value>>()
            })
            .collect())
    }

    /// Reads the column layout of `relation` from `information_schema.columns`.
    ///
    /// Unrecognised type names are mapped to `DataType::String`, and array
    /// columns are reported as arrays of strings.
    ///
    /// NOTE(review): the lookup filters on the table name only, so tables with
    /// the same name in different schemas would contribute columns together —
    /// confirm whether a schema filter is wanted.
    ///
    /// # Errors
    ///
    /// Returns `RealError::Backend` when the catalog query fails or a catalog
    /// column cannot be decoded as text.
    fn get_schema(&self, conn: &mut Self::Connection, relation: &str) -> Result<Schema> {
        let query = "SELECT column_name, data_type, is_nullable\nFROM information_schema.columns\nWHERE table_name = $1\nORDER BY ordinal_position";
        let rows = conn
            .query(query, &[&relation])
            .map_err(Self::backend_error)?;
        let mut schema = Schema::new(relation);
        for row in &rows {
            let name: String = row.try_get(0).map_err(Self::backend_error)?;
            let type_name: String = row.try_get(1).map_err(Self::backend_error)?;
            let nullable: String = row.try_get(2).map_err(Self::backend_error)?;
            let data_type = match type_name.to_lowercase().as_str() {
                "integer" | "int" | "int2" | "int4" | "int8" | "smallint" | "bigint" => {
                    DataType::Integer
                }
                "real" | "float4" | "double precision" | "float8" | "numeric" | "decimal" => {
                    DataType::Float
                }
                "text" | "varchar" | "character varying" | "char" | "character" => {
                    DataType::String
                }
                "bytea" => DataType::Bytes,
                "boolean" | "bool" => DataType::Boolean,
                // `information_schema` reports the long SQL spellings; the
                // short aliases are kept for compatibility.
                "timestamp"
                | "timestamptz"
                | "timestamp without time zone"
                | "timestamp with time zone"
                | "date"
                | "time"
                | "timetz"
                | "time without time zone"
                | "time with time zone" => DataType::Timestamp,
                "json" | "jsonb" => DataType::Json,
                // `data_type` is "ARRAY" for array columns; the leading
                // underscore form is the internal array type name.
                "array" => DataType::Array(Box::new(DataType::String)),
                other if other.starts_with('_') => DataType::Array(Box::new(DataType::String)),
                _ => DataType::String,
            };
            schema.columns.push(Column {
                name,
                data_type,
                nullable: nullable.to_lowercase() == "yes",
            });
        }
        Ok(schema)
    }
}
/// Error text returned by every method of the stub implementation below.
#[cfg(not(feature = "backend-postgres"))]
const POSTGRES_DISABLED_MSG: &str =
    "PostgreSQL backend not enabled. Enable 'backend-postgres' feature.";

/// Stub [`Backend`] implementation used when the `backend-postgres` feature
/// is off: every method returns a `RealError::Backend` saying so.
#[cfg(not(feature = "backend-postgres"))]
impl Backend for PostgresBackend {
    type Connection = ();
    type CompiledQuery = PostgresQuery;

    /// Always fails; compilation requires the `backend-postgres` feature.
    fn compile(&self, _expr: &Expr) -> Result<Self::CompiledQuery> {
        Err(RealError::Backend(String::from(POSTGRES_DISABLED_MSG)))
    }

    /// Always fails; execution requires the `backend-postgres` feature.
    fn execute(
        &self,
        _conn: &mut Self::Connection,
        _query: &Self::CompiledQuery,
    ) -> Result<ResultSet> {
        Err(RealError::Backend(String::from(POSTGRES_DISABLED_MSG)))
    }

    /// Always fails; schema lookup requires the `backend-postgres` feature.
    fn get_schema(&self, _conn: &mut Self::Connection, _relation: &str) -> Result<Schema> {
        Err(RealError::Backend(String::from(POSTGRES_DISABLED_MSG)))
    }
}
// These tests unwrap the result of `compile`, which only succeeds when the
// `backend-postgres` feature is enabled; without it the stub implementation
// returns an error. Gate the module on the feature so a default test run does
// not panic.
#[cfg(all(test, feature = "backend-postgres"))]
mod tests {
    use super::*;
    use crate::algebra::{ColumnRef, Predicate};

    /// A bare relation compiles to a plain `SELECT *`.
    #[test]
    fn test_postgres_compile_relation() {
        let backend = PostgresBackend::new();
        let schema = Schema::new("users").with_column("id", DataType::Integer);
        let expr = Expr::relation("users", schema);
        let query = backend.compile(&expr).unwrap();
        assert_eq!(query.sql, "SELECT * FROM users");
    }

    /// A selection adds a `WHERE` clause and binds its literal as a parameter.
    #[test]
    fn test_postgres_compile_select() {
        let backend = PostgresBackend::new();
        let schema = Schema::new("users")
            .with_column("id", DataType::Integer)
            .with_column("age", DataType::Integer);
        let expr = Expr::relation("users", schema).select(Predicate::Compare {
            left: ColumnRef::new("age"),
            op: CompareOp::Gt,
            right: Operand::Literal(Value::Integer(25)),
        });
        let query = backend.compile(&expr).unwrap();
        assert!(query.sql.contains("WHERE"));
        assert!(query.sql.contains("age >"));
        assert_eq!(query.params.len(), 1);
    }

    /// A grouped aggregate emits both the aggregate call and `GROUP BY`.
    #[test]
    fn test_postgres_compile_aggregate() {
        let backend = PostgresBackend::new();
        let schema = Schema::new("sales")
            .with_column("region", DataType::String)
            .with_column("amount", DataType::Float);
        let expr = Expr::Aggregate {
            input: Box::new(Expr::relation("sales", schema)),
            group_by: vec!["region".to_string()],
            aggregates: vec![crate::algebra::AggregateFunc {
                name: "total".to_string(),
                func: AggregateType::Sum,
                input: "amount".to_string(),
            }],
        };
        let query = backend.compile(&expr).unwrap();
        assert!(query.sql.contains("GROUP BY"));
        assert!(query.sql.contains("SUM(amount)"));
    }
}