use crate::algebra::{AggregateType, CompareOp, Expr, JoinCondition, Operand, Predicate};
use crate::backends::Backend;
use crate::schema::{Column, DataType, ResultSet, Row, Schema, Value};
use crate::{RealError, Result};
#[cfg(feature = "backend-sqlite")]
use rusqlite::{Connection, params_from_iter};
/// Backend that compiles relational-algebra [`Expr`] trees into SQLite SQL and,
/// when the `backend-sqlite` feature is enabled, executes them via `rusqlite`.
///
/// The backend is stateless, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct SQLiteBackend;
/// A compiled SQLite statement together with its positional bind values.
#[derive(Debug, Clone)]
pub struct SQLQuery {
/// SQL text; literal values appear as positional `?` placeholders.
pub sql: String,
/// Values bound to the `?` placeholders in `sql`, in left-to-right order
/// (they are passed to `params_from_iter` when the query is executed).
pub params: Vec<Value>,
/// Schema the compiler associates with the result rows. This may be
/// incomplete: joins currently use a bare `Schema::new("join_result")`.
pub result_schema: Schema,
}
impl SQLiteBackend {
    /// Creates a new SQLite backend. The backend is a unit struct and holds no state.
    pub fn new() -> Self {
        Self
    }

    /// Compiles `pred` into an SQL boolean expression.
    ///
    /// Literal values are never spliced into the SQL text. Each one is replaced by a
    /// `?` placeholder and pushed onto `params`, in the same order as its placeholder
    /// appears in the returned fragment. Positional parameters bind strictly left to
    /// right, so callers must push any placeholders that precede this fragment in the
    /// final statement *before* calling.
    ///
    /// Column names are emitted verbatim (unquoted).
    ///
    /// # Errors
    ///
    /// Propagates errors from nested sub-predicates. The variants handled here do not
    /// themselves produce errors.
    // Only the `backend-sqlite` `Backend` impl calls this; avoid a dead-code
    // warning when that feature is disabled.
    #[cfg_attr(not(feature = "backend-sqlite"), allow(dead_code))]
    fn compile_predicate(&self, pred: &Predicate, params: &mut Vec<Value>) -> Result<String> {
        match pred {
            Predicate::Compare { left, op, right } => {
                let op_sql = match op {
                    CompareOp::Eq => "=",
                    CompareOp::NotEq => "!=",
                    CompareOp::Lt => "<",
                    CompareOp::Lte => "<=",
                    CompareOp::Gt => ">",
                    CompareOp::Gte => ">=",
                };
                let right_sql = match right {
                    Operand::Column(col) => col.name.clone(),
                    Operand::Literal(val) => {
                        params.push(val.clone());
                        "?".to_string()
                    }
                };
                Ok(format!("{} {} {}", left.name, op_sql, right_sql))
            }
            Predicate::And(left, right) => {
                // Left operand first so its placeholders bind before the right's.
                let left_sql = self.compile_predicate(left, params)?;
                let right_sql = self.compile_predicate(right, params)?;
                Ok(format!("({} AND {})", left_sql, right_sql))
            }
            Predicate::Or(left, right) => {
                let left_sql = self.compile_predicate(left, params)?;
                let right_sql = self.compile_predicate(right, params)?;
                Ok(format!("({} OR {})", left_sql, right_sql))
            }
            Predicate::Not(inner) => {
                let inner_sql = self.compile_predicate(inner, params)?;
                Ok(format!("NOT ({})", inner_sql))
            }
            Predicate::In { column, values } => {
                // Push the bind values in an explicit loop rather than as a hidden side
                // effect of a lazy `map` adaptor.
                let mut placeholders = Vec::new();
                for value in values.iter() {
                    params.push(value.clone());
                    placeholders.push("?");
                }
                Ok(format!("{} IN ({})", column.name, placeholders.join(", ")))
            }
            Predicate::Like { column, pattern } => {
                params.push(Value::String(pattern.clone()));
                Ok(format!("{} LIKE ?", column.name))
            }
            Predicate::IsNull(column) => Ok(format!("{} IS NULL", column.name)),
            Predicate::Between { column, low, high } => {
                params.push(low.clone());
                params.push(high.clone());
                Ok(format!("{} BETWEEN ? AND ?", column.name))
            }
        }
    }
}
impl Default for SQLiteBackend {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "backend-sqlite")]
impl Backend for SQLiteBackend {
    type Connection = Connection;
    type CompiledQuery = SQLQuery;

    /// Compiles a relational-algebra expression into a single SQLite statement.
    ///
    /// Every operator wraps the SQL of its input(s) in a sub-select, so `?`
    /// placeholders appear in the final text in input order. `params` is built in
    /// that same textual order: left input, then right input, then any placeholders
    /// the operator itself emits (WHERE / ON). This matters because SQLite binds
    /// positional `?` parameters strictly left to right.
    ///
    /// # Errors
    ///
    /// Returns [`RealError::Backend`] in three cases:
    /// - a predicate cannot be compiled;
    /// - an aggregate has neither group-by columns nor aggregate functions;
    /// - `expr` is a variant this backend does not support yet.
    fn compile(&self, expr: &Expr) -> Result<Self::CompiledQuery> {
        /// Combines two compiled queries with a compound operator
        /// (`UNION` / `INTERSECT` / `EXCEPT`).
        ///
        /// SQLite rejects parenthesised operands such as `(SELECT ..) UNION (SELECT ..)`,
        /// so each side is wrapped as `SELECT * FROM (..)` instead.
        fn compound(left: SQLQuery, right: SQLQuery, operator: &str) -> SQLQuery {
            let mut params = left.params;
            params.extend(right.params);
            SQLQuery {
                sql: format!(
                    "SELECT * FROM ({}) {} SELECT * FROM ({})",
                    left.sql, operator, right.sql
                ),
                params,
                result_schema: left.result_schema,
            }
        }

        let query = match expr {
            Expr::Relation { name, schema } => SQLQuery {
                sql: format!("SELECT * FROM {}", name),
                params: vec![],
                result_schema: schema.clone(),
            },
            Expr::Select { input, predicate } => {
                let inner = self.compile(input)?;
                // The inner query's placeholders precede the WHERE clause in the text,
                // so keep them and append the predicate's values after them.
                let mut params = inner.params;
                let where_clause = self.compile_predicate(predicate, &mut params)?;
                SQLQuery {
                    sql: format!("SELECT * FROM ({}) WHERE {}", inner.sql, where_clause),
                    params,
                    result_schema: inner.result_schema,
                }
            }
            Expr::Project { input, columns } => {
                let inner = self.compile(input)?;
                SQLQuery {
                    sql: format!("SELECT {} FROM ({})", columns.join(", "), inner.sql),
                    params: inner.params,
                    result_schema: expr.infer_schema(),
                }
            }
            Expr::Join { left, right, condition } => {
                let left_query = self.compile(left)?;
                let right_query = self.compile(right)?;
                // Text order is: left sub-select, right sub-select, then the ON clause.
                let mut params = left_query.params;
                params.extend(right_query.params);
                let join_clause = match condition {
                    JoinCondition::Using(cols) => format!("USING ({})", cols.join(", ")),
                    JoinCondition::On(pred) => {
                        format!("ON {}", self.compile_predicate(pred, &mut params)?)
                    }
                };
                SQLQuery {
                    sql: format!(
                        "SELECT * FROM ({}) JOIN ({}) {}",
                        left_query.sql, right_query.sql, join_clause
                    ),
                    params,
                    // NOTE(review): no columns are recorded for join results.
                    result_schema: Schema::new("join_result"),
                }
            }
            Expr::Aggregate { input, group_by, aggregates } => {
                let inner = self.compile(input)?;
                let mut select_items: Vec<String> = Vec::new();
                if !group_by.is_empty() {
                    select_items.push(group_by.join(", "));
                }
                select_items.extend(aggregates.iter().map(|agg| {
                    let func_name = match agg.func {
                        AggregateType::Count => "COUNT",
                        AggregateType::Sum => "SUM",
                        AggregateType::Avg => "AVG",
                        AggregateType::Min => "MIN",
                        AggregateType::Max => "MAX",
                    };
                    format!("{}({}) AS {}", func_name, agg.input, agg.name)
                }));
                // An empty select list would otherwise produce invalid SQL
                // (`SELECT  FROM ...` or a trailing comma).
                if select_items.is_empty() {
                    return Err(RealError::Backend(
                        "SQLite backend cannot compile an aggregate with no group-by columns and no aggregate functions"
                            .to_string(),
                    ));
                }
                let select_list = select_items.join(", ");
                let sql = if group_by.is_empty() {
                    format!("SELECT {} FROM ({})", select_list, inner.sql)
                } else {
                    format!(
                        "SELECT {} FROM ({}) GROUP BY {}",
                        select_list,
                        inner.sql,
                        group_by.join(", ")
                    )
                };
                SQLQuery {
                    sql,
                    params: inner.params,
                    result_schema: expr.infer_schema(),
                }
            }
            Expr::Union { left, right } => {
                compound(self.compile(left)?, self.compile(right)?, "UNION")
            }
            Expr::Intersect { left, right } => {
                compound(self.compile(left)?, self.compile(right)?, "INTERSECT")
            }
            Expr::Difference { left, right } => {
                compound(self.compile(left)?, self.compile(right)?, "EXCEPT")
            }
            Expr::Rename { input, from, to } => {
                let inner = self.compile(input)?;
                // NOTE(review): the generated SQL does not rename the column; only the
                // inferred result schema presumably reflects the new name. Emitting
                // `from AS to` needs the input's full column list, which
                // `result_schema` may not hold (joins have none). TODO confirm intent.
                SQLQuery {
                    sql: format!(
                        "SELECT * FROM ({}) /* Rename {} to {} */",
                        inner.sql, from, to
                    ),
                    params: inner.params,
                    result_schema: expr.infer_schema(),
                }
            }
            Expr::Sort { input, columns } => {
                let inner = self.compile(input)?;
                if columns.is_empty() {
                    // `ORDER BY` with an empty list is a syntax error; sorting by
                    // nothing is a no-op.
                    inner
                } else {
                    let order_by = columns
                        .iter()
                        .map(|(col, order)| {
                            let order_str = match order {
                                crate::algebra::SortOrder::Asc => "ASC",
                                crate::algebra::SortOrder::Desc => "DESC",
                            };
                            format!("{} {}", col, order_str)
                        })
                        .collect::<Vec<_>>()
                        .join(", ");
                    SQLQuery {
                        sql: format!("SELECT * FROM ({}) ORDER BY {}", inner.sql, order_by),
                        params: inner.params,
                        result_schema: inner.result_schema,
                    }
                }
            }
            Expr::Limit { input, count } => {
                let inner = self.compile(input)?;
                SQLQuery {
                    sql: format!("SELECT * FROM ({}) LIMIT {}", inner.sql, count),
                    params: inner.params,
                    result_schema: inner.result_schema,
                }
            }
            Expr::Offset { input, count } => {
                let inner = self.compile(input)?;
                // SQLite only accepts OFFSET as part of a LIMIT clause; a negative
                // LIMIT means "no upper bound".
                SQLQuery {
                    sql: format!("SELECT * FROM ({}) LIMIT -1 OFFSET {}", inner.sql, count),
                    params: inner.params,
                    result_schema: inner.result_schema,
                }
            }
            _ => {
                return Err(RealError::Backend(format!(
                    "SQLite backend does not yet support: {:?}",
                    expr
                )))
            }
        };
        Ok(query)
    }

    /// Prepares and runs `query` on `conn`, binding `query.params` positionally.
    ///
    /// Each row becomes a `Vec<Value>`. SQLite storage classes map as follows:
    /// - NULL to `Null`
    /// - INTEGER to `Integer`
    /// - REAL to `Float`
    /// - TEXT to `String`
    /// - BLOB to `Bytes`
    ///
    /// # Errors
    ///
    /// Returns [`RealError::Backend`] if preparing, binding, stepping or reading
    /// any row fails.
    fn execute(&self, conn: &mut Self::Connection, query: &Self::CompiledQuery) -> Result<ResultSet> {
        let mut stmt = conn
            .prepare(&query.sql)
            .map_err(|e| RealError::Backend(e.to_string()))?;
        let column_count = stmt.column_count();
        let rows = stmt
            .query_map(params_from_iter(query.params.iter()), |row| {
                let mut values = Vec::with_capacity(column_count);
                for idx in 0..column_count {
                    let raw: rusqlite::types::Value = row.get(idx)?;
                    values.push(match raw {
                        rusqlite::types::Value::Null => Value::Null,
                        rusqlite::types::Value::Integer(n) => Value::Integer(n),
                        rusqlite::types::Value::Real(f) => Value::Float(f),
                        rusqlite::types::Value::Text(s) => Value::String(s),
                        rusqlite::types::Value::Blob(b) => Value::Bytes(b),
                    });
                }
                Ok(values)
            })
            .map_err(|e| RealError::Backend(e.to_string()))?
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(|e| RealError::Backend(e.to_string()))?;
        Ok(rows)
    }

    /// Reads the column list of `relation` via `PRAGMA table_info`.
    ///
    /// A column is nullable unless it carries a NOT NULL constraint. A relation
    /// that does not exist yields zero rows from the pragma, so the returned
    /// schema has no columns.
    ///
    /// # Errors
    ///
    /// Returns [`RealError::Backend`] if the pragma cannot be prepared or a row
    /// cannot be read.
    fn get_schema(&self, conn: &mut Self::Connection, relation: &str) -> Result<Schema> {
        /// Maps a declared column type to a [`DataType`] using SQLite's
        /// type-affinity rules. A substring match is used, so `VARCHAR(255)`,
        /// `INT8` and `DOUBLE PRECISION` are all recognised.
        fn declared_type_to_data_type(declared: &str) -> DataType {
            let upper = declared.trim().to_uppercase();
            if upper == "BOOLEAN" || upper == "BOOL" {
                // SQLite has no boolean storage class, but honour the declared intent.
                DataType::Boolean
            } else if upper.contains("INT") {
                DataType::Integer
            } else if upper.contains("CHAR") || upper.contains("CLOB") || upper.contains("TEXT") {
                DataType::String
            } else if upper.contains("BLOB") {
                DataType::Bytes
            } else if upper.contains("REAL") || upper.contains("FLOA") || upper.contains("DOUB") {
                DataType::Float
            } else if upper.starts_with("NUMERIC") || upper.starts_with("DECIMAL") {
                DataType::Float
            } else {
                // Empty and unrecognised declarations fall back to strings.
                DataType::String
            }
        }

        // Quote the name as an identifier (doubling embedded quotes) so a
        // caller-supplied relation name cannot inject SQL or clash with keywords.
        let query = format!("PRAGMA table_info(\"{}\")", relation.replace('"', "\"\""));
        let mut stmt = conn
            .prepare(&query)
            .map_err(|e| RealError::Backend(e.to_string()))?;
        let mut schema = Schema::new(relation);
        let rows = stmt
            .query_map([], |row| {
                // table_info columns: cid, name, type, notnull, dflt_value, pk.
                let name: String = row.get(1)?;
                let type_name: String = row.get(2)?;
                let not_null: i32 = row.get(3)?;
                Ok(Column {
                    name,
                    data_type: declared_type_to_data_type(&type_name),
                    nullable: not_null == 0,
                })
            })
            .map_err(|e| RealError::Backend(e.to_string()))?;
        for row in rows {
            schema
                .columns
                .push(row.map_err(|e| RealError::Backend(e.to_string()))?);
        }
        Ok(schema)
    }
}
#[cfg(not(feature = "backend-sqlite"))]
impl Backend for SQLiteBackend {
    // rusqlite's `Connection` is only imported with the feature enabled, so a unit
    // placeholder keeps the trait implementable.
    type Connection = ();
    type CompiledQuery = SQLQuery;

    /// Always fails: SQL compilation requires the `backend-sqlite` feature.
    fn compile(&self, _expr: &Expr) -> Result<Self::CompiledQuery> {
        let reason = "SQLite backend not enabled. Enable 'backend-sqlite' feature.";
        Err(RealError::Backend(reason.to_owned()))
    }

    /// Always fails: execution requires the `backend-sqlite` feature.
    fn execute(&self, _conn: &mut Self::Connection, _query: &Self::CompiledQuery) -> Result<ResultSet> {
        let reason = "SQLite backend not enabled. Enable 'backend-sqlite' feature.";
        Err(RealError::Backend(reason.to_owned()))
    }

    /// Always fails: schema introspection requires the `backend-sqlite` feature.
    fn get_schema(&self, _conn: &mut Self::Connection, _relation: &str) -> Result<Schema> {
        let reason = "SQLite backend not enabled. Enable 'backend-sqlite' feature.";
        Err(RealError::Backend(reason.to_owned()))
    }
}
// These tests compile real SQL, which only the `backend-sqlite` impl can do; the
// fallback impl always returns an error, so gate the module on the feature.
#[cfg(all(test, feature = "backend-sqlite"))]
mod tests {
    use super::*;
    use crate::algebra::ColumnRef;

    #[test]
    fn test_sqlite_compile_relation() {
        let backend = SQLiteBackend::new();
        let schema = Schema::new("users").with_column("id", DataType::Integer);
        let expr = Expr::relation("users", schema);
        let query = backend.compile(&expr).unwrap();
        assert_eq!(query.sql, "SELECT * FROM users");
        assert!(query.params.is_empty());
    }

    #[test]
    fn test_sqlite_compile_select() {
        let backend = SQLiteBackend::new();
        let schema = Schema::new("users")
            .with_column("id", DataType::Integer)
            .with_column("age", DataType::Integer);
        let expr = Expr::relation("users", schema).select(Predicate::Compare {
            left: ColumnRef::new("age"),
            op: CompareOp::Gt,
            right: Operand::Literal(Value::Integer(25)),
        });
        let query = backend.compile(&expr).unwrap();
        assert_eq!(query.sql, "SELECT * FROM (SELECT * FROM users) WHERE age > ?");
        assert_eq!(query.params.len(), 1);
        assert!(matches!(query.params[0], Value::Integer(25)));
    }
}