use crate::algebra::{CompareOp, Expr, Operand, Predicate};
use crate::backends::Backend;
use crate::schema::{DataType, ResultSet, Row, Schema, Value};
use crate::{RealError, Result};
/// Encodes `bytes` as standard base64 (RFC 4648 alphabet, `=` padding).
///
/// Each group of up to three input bytes is packed into a 24-bit integer and
/// emitted as four 6-bit symbols. A trailing group of one or two bytes
/// produces `==` or `=` padding respectively.
fn base64_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the low six bits of `bits` to its alphabet symbol.
    let symbol = |bits: u32| char::from(ALPHABET[(bits & 0x3f) as usize]);

    let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);
    for group in bytes.chunks(3) {
        // Missing trailing bytes of a short final group count as zero.
        let byte_at = |i: usize| u32::from(group.get(i).copied().unwrap_or(0));
        let packed = (byte_at(0) << 16) | (byte_at(1) << 8) | byte_at(2);

        encoded.push(symbol(packed >> 18));
        encoded.push(symbol(packed >> 12));
        encoded.push(if group.len() > 1 { symbol(packed >> 6) } else { '=' });
        encoded.push(if group.len() > 2 { symbol(packed) } else { '=' });
    }
    encoded
}
/// Relational-algebra backend that targets MongoDB aggregation pipelines.
///
/// The backend is stateless; all per-query information lives in [`MongoQuery`].
#[derive(Debug, Clone, Copy)]
pub struct MongoDBBackend;
/// A compiled MongoDB query: an aggregation pipeline that starts from one collection.
#[derive(Debug, Clone)]
pub struct MongoQuery {
/// Name of the collection the pipeline reads from.
pub collection: String,
/// Aggregation stages, appended bottom-up as the expression tree is compiled.
pub pipeline: Vec<MongoStage>,
/// Schema the query's output rows are expected to have (taken from the base
/// relation, or recomputed via `Expr::infer_schema` after shape-changing stages).
pub result_schema: Schema,
}
/// One stage of a MongoDB aggregation pipeline.
///
/// String payloads hold the JSON body that follows the stage operator (for
/// example, the filter document of a `$match`). `Limit` and `Skip` carry
/// their counts directly.
#[derive(Debug, Clone)]
pub enum MongoStage {
    /// `$match` — filter document.
    Match(String),
    /// `$project` — projection document.
    Project(String),
    /// `$lookup` — lookup specification document.
    Lookup(String),
    /// `$group` — grouping document (`_id` plus accumulators).
    Group(String),
    /// `$sort` — sort-key document.
    Sort(String),
    /// `$limit` — maximum number of documents to pass on.
    Limit(usize),
    /// `$skip` — number of documents to drop.
    Skip(usize),
}
impl MongoDBBackend {
pub fn new() -> Self {
Self
}
fn compile_predicate(&self, pred: &Predicate) -> Result<String> {
match pred {
Predicate::Compare { left, op, right } => {
let field = &left.name;
let (mongo_op, value) = match (op, right) {
(CompareOp::Eq, Operand::Literal(v)) => ("$eq", self.value_to_json(v)),
(CompareOp::NotEq, Operand::Literal(v)) => ("$ne", self.value_to_json(v)),
(CompareOp::Lt, Operand::Literal(v)) => ("$lt", self.value_to_json(v)),
(CompareOp::Lte, Operand::Literal(v)) => ("$lte", self.value_to_json(v)),
(CompareOp::Gt, Operand::Literal(v)) => ("$gt", self.value_to_json(v)),
(CompareOp::Gte, Operand::Literal(v)) => ("$gte", self.value_to_json(v)),
(op, Operand::Column(col)) => {
return Ok(format!(
r#"{{ "{}": {{ "{}": "${}" }} }}"#,
field,
match op {
CompareOp::Eq => "$eq",
CompareOp::NotEq => "$ne",
CompareOp::Lt => "$lt",
CompareOp::Lte => "$lte",
CompareOp::Gt => "$gt",
CompareOp::Gte => "$gte",
},
col.name
));
}
};
Ok(format!(r#"{{ "{}": {{ "{}": {} }} }}"#, field, mongo_op, value))
}
Predicate::And(left, right) => {
let left_json = self.compile_predicate(left)?;
let right_json = self.compile_predicate(right)?;
Ok(format!(r#"{{ "$and": [{}, {}] }}"#, left_json, right_json))
}
Predicate::Or(left, right) => {
let left_json = self.compile_predicate(left)?;
let right_json = self.compile_predicate(right)?;
Ok(format!(r#"{{ "$or": [{}, {}] }}"#, left_json, right_json))
}
Predicate::Not(inner) => {
let inner_json = self.compile_predicate(inner)?;
Ok(format!(r#"{{ "$not": {} }}"#, inner_json))
}
Predicate::In { column, values } => {
let values_json = values
.iter()
.map(|v| self.value_to_json(v))
.collect::<Vec<_>>()
.join(", ");
Ok(format!(r#"{{ "{}": {{ "$in": [{}] }} }}"#, column.name, values_json))
}
Predicate::Like { column, pattern } => {
let regex = pattern.replace('%', ".*").replace('_', ".");
Ok(format!(r#"{{ "{}": {{ "$regex": "{}" }} }}"#, column.name, regex))
}
Predicate::IsNull(column) => {
Ok(format!(r#"{{ "{}": null }}"#, column.name))
}
Predicate::Between { column, low, high } => {
let low_json = self.value_to_json(low);
let high_json = self.value_to_json(high);
Ok(format!(
r#"{{ "{}": {{ "$gte": {}, "$lte": {} }} }}"#,
column.name, low_json, high_json
))
}
}
}
fn value_to_json(&self, value: &Value) -> String {
match value {
Value::Integer(i) => i.to_string(),
Value::Float(f) => f.to_string(),
Value::String(s) => format!(r#""{}""#, s),
Value::Boolean(b) => b.to_string(),
Value::Null => "null".to_string(),
Value::Bytes(b) => {
let base64 = base64_encode(b);
format!(r#"{{"$binary": {{"base64": "{}", "subType": "00"}}}}"#, base64)
}
Value::Timestamp(ts) => {
format!(r#"{{"$date": {}}}"#, ts)
}
Value::Decimal(d) => {
format!(r#"{{"$numberDecimal": "{}"}}"#, d)
}
Value::Json(j) => {
j.clone()
}
Value::Array(arr) => {
let items: Vec<String> = arr.iter().map(|v| self.value_to_json(v)).collect();
format!("[{}]", items.join(", "))
}
Value::Vector(v) => {
let items: Vec<String> = v.iter().map(|f| f.to_string()).collect();
format!("[{}]", items.join(", "))
}
}
}
}
impl Default for MongoDBBackend {
fn default() -> Self {
Self::new()
}
}
impl Backend for MongoDBBackend {
type Connection = (); type CompiledQuery = MongoQuery;
fn compile(&self, expr: &Expr) -> Result<Self::CompiledQuery> {
match expr {
Expr::Relation { name, schema } => Ok(MongoQuery {
collection: name.clone(),
pipeline: vec![],
result_schema: schema.clone(),
}),
Expr::Select { input, predicate } => {
let mut query = self.compile(input)?;
let match_stage = self.compile_predicate(predicate)?;
query.pipeline.push(MongoStage::Match(match_stage));
Ok(query)
}
Expr::Project { input, columns } => {
let mut query = self.compile(input)?;
let project_fields = columns
.iter()
.map(|c| format!(r#""{}": 1"#, c))
.collect::<Vec<_>>()
.join(", ");
let project_stage = format!(r#"{{ {} }}"#, project_fields);
query.pipeline.push(MongoStage::Project(project_stage));
query.result_schema = expr.infer_schema();
Ok(query)
}
Expr::Join { left, right, condition } => {
let left_query = self.compile(left)?;
let right_query = self.compile(right)?;
let (local_field, foreign_field) = match condition {
crate::algebra::JoinCondition::Using(cols) => {
if !cols.is_empty() {
(cols[0].clone(), cols[0].clone())
} else {
("_id".to_string(), "_id".to_string())
}
}
crate::algebra::JoinCondition::On(_pred) => {
("_id".to_string(), "_id".to_string())
}
};
let lookup = format!(
r#"{{ "from": "{}", "localField": "{}", "foreignField": "{}", "as": "joined" }}"#,
right_query.collection, local_field, foreign_field
);
let mut query = left_query;
query.pipeline.push(MongoStage::Lookup(lookup));
Ok(query)
}
Expr::Aggregate { input, group_by, aggregates } => {
let mut query = self.compile(input)?;
let id_expr = if group_by.is_empty() {
"null".to_string()
} else if group_by.len() == 1 {
format!(r#""${}""#, group_by[0])
} else {
let fields: Vec<String> = group_by
.iter()
.map(|f| format!(r#""{}": "${}"#, f, f))
.collect();
format!(r#"{{ {} }}"#, fields.join(", "))
};
let agg_exprs: Vec<String> = aggregates
.iter()
.map(|agg| {
let mongo_func = match agg.func {
crate::algebra::AggregateType::Count => "$sum: 1".to_string(),
crate::algebra::AggregateType::Sum => format!("$sum: \"${}\"", agg.input),
crate::algebra::AggregateType::Avg => format!("$avg: \"${}\"", agg.input),
crate::algebra::AggregateType::Min => format!("$min: \"${}\"", agg.input),
crate::algebra::AggregateType::Max => format!("$max: \"${}\"", agg.input),
};
format!(r#""{}" {{ {} }}"#, agg.name, mongo_func)
})
.collect();
let group_stage = format!(
r#"{{ "_id": {}, {} }}"#,
id_expr,
agg_exprs.join(", ")
);
query.pipeline.push(MongoStage::Group(group_stage));
query.result_schema = expr.infer_schema();
Ok(query)
}
Expr::Union { left, right } => {
let left_query = self.compile(left)?;
let right_query = self.compile(right)?;
let union_stage = format!(r#"{{ "coll": "{}" }}"#, right_query.collection);
let mut query = left_query;
query.pipeline.push(MongoStage::Match(union_stage));
Ok(query)
}
Expr::Intersect { left, right } => {
Err(RealError::Backend(
"MongoDB INTERSECT requires client-side processing".into(),
))
}
Expr::Difference { left, right } => {
Err(RealError::Backend(
"MongoDB DIFFERENCE requires client-side processing".into(),
))
}
Expr::Rename { input, from, to } => {
let mut query = self.compile(input)?;
let rename_stage = format!(
r#"{{ "{}": "${}", "{}": 0 }}"#,
to, from, from
);
query.pipeline.push(MongoStage::Project(rename_stage));
query.result_schema = expr.infer_schema();
Ok(query)
}
Expr::Sort { input, columns } => {
let mut query = self.compile(input)?;
let sort_fields = columns
.iter()
.map(|(col, order)| {
let order_val = match order {
crate::algebra::SortOrder::Asc => 1,
crate::algebra::SortOrder::Desc => -1,
};
format!(r#""{}": {}"#, col, order_val)
})
.collect::<Vec<_>>()
.join(", ");
let sort_stage = format!(r#"{{ {} }}"#, sort_fields);
query.pipeline.push(MongoStage::Sort(sort_stage));
Ok(query)
}
Expr::Limit { input, count } => {
let mut query = self.compile(input)?;
query.pipeline.push(MongoStage::Limit(*count));
Ok(query)
}
Expr::Offset { input, count } => {
let mut query = self.compile(input)?;
query.pipeline.push(MongoStage::Skip(*count));
Ok(query)
}
_ => Err(RealError::Backend(format!(
"MongoDB backend does not yet support: {:?}",
expr
))),
}
}
fn execute(&self, _conn: &mut Self::Connection, query: &Self::CompiledQuery) -> Result<ResultSet> {
println!("MongoDB Query:");
println!(" Collection: {}", query.collection);
println!(" Pipeline:");
for (i, stage) in query.pipeline.iter().enumerate() {
println!(" Stage {}: {:?}", i, stage);
}
Ok(Vec::new())
}
fn get_schema(&self, _conn: &mut Self::Connection, _relation: &str) -> Result<Schema> {
Err(RealError::Backend(
"MongoDB schema introspection not implemented".into(),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::{ColumnRef, Predicate};

    /// Two-column `users` schema shared by the select and project tests.
    fn users_schema() -> Schema {
        Schema::new("users")
            .with_column("name", DataType::String)
            .with_column("age", DataType::Integer)
    }

    #[test]
    fn test_mongodb_compile_relation() {
        let schema = Schema::new("users").with_column("name", DataType::String);
        let compiled = MongoDBBackend::new()
            .compile(&Expr::relation("users", schema))
            .unwrap();
        assert_eq!(compiled.collection, "users");
        assert!(compiled.pipeline.is_empty());
    }

    #[test]
    fn test_mongodb_compile_select() {
        let age_over_25 = Predicate::Compare {
            left: ColumnRef::new("age"),
            op: CompareOp::Gt,
            right: Operand::Literal(Value::Integer(25)),
        };
        let expr = Expr::relation("users", users_schema()).select(age_over_25);
        let compiled = MongoDBBackend::new().compile(&expr).unwrap();

        assert_eq!(compiled.collection, "users");
        assert_eq!(compiled.pipeline.len(), 1);
        match &compiled.pipeline[0] {
            MongoStage::Match(filter) => {
                assert!(filter.contains("$gt"));
                assert!(filter.contains("25"));
            }
            _ => panic!("Expected Match stage"),
        }
    }

    #[test]
    fn test_mongodb_compile_project() {
        let expr = Expr::relation("users", users_schema()).project(vec!["name".to_string()]);
        let compiled = MongoDBBackend::new().compile(&expr).unwrap();

        assert_eq!(compiled.pipeline.len(), 1);
        match &compiled.pipeline[0] {
            MongoStage::Project(projection) => assert!(projection.contains("name")),
            _ => panic!("Expected Project stage"),
        }
    }
}