use super::ast::*;
use crate::node::Node;
use crate::storage::memtable::MemTable;
use crate::VectorType;
use std::collections::HashMap;
/// Rows produced by [`execute`]: one map per surviving match, keyed by the names of the
/// `return_vars` that were bound in that match.
pub type QueryResult<T> = Vec<HashMap<String, Node<T>>>;
/// Executes a parsed [`Query`] against `memtable` and returns one row per match.
///
/// Matching proceeds in three phases:
/// 1. every node satisfying the first node pattern seeds a partial match;
/// 2. edge pattern `i` extends each partial match along the outgoing edges of the node
///    it currently stands on, keeping only edges whose label (when the pattern names one)
///    matches and whose target satisfies node pattern `i + 1`;
/// 3. the optional `where_clause` filters the complete matches, and each survivor is
///    projected down to `query.return_vars`. Return variables that were never bound are
///    omitted from that row.
///
/// Anonymous node patterns (no variable) are traversed exactly like named ones; they
/// simply do not appear in the bindings. A pattern with fewer than `edges.len() + 1`
/// node patterns (including an empty one) yields no rows.
///
/// NOTE(review): a variable that appears twice in the pattern is rebound to the latest
/// node it matched; the two occurrences are not required to be the same node.
pub fn execute<T: VectorType>(query: &Query, memtable: &MemTable<T>) -> QueryResult<T> {
    let pattern = &query.pattern;
    // Every edge needs a node pattern on each side. Anything shorter cannot match, and
    // would otherwise panic on the indexing below.
    if pattern.nodes.len() <= pattern.edges.len() {
        return Vec::new();
    }
    let first_node_pat = &pattern.nodes[0];

    // A partial match is its variable bindings plus the id of the node the traversal is
    // currently standing on. Tracking the id separately (instead of looking it up via the
    // node's variable) lets anonymous node patterns take part in multi-hop paths.
    let mut frontier: Vec<(HashMap<String, Node<T>>, u64)> =
        find_candidates(first_node_pat, memtable)
            .into_iter()
            .map(|start_node| {
                let start_id = start_node.id;
                let mut binding = HashMap::new();
                if let Some(var) = &first_node_pat.var {
                    binding.insert(var.clone(), start_node);
                }
                (binding, start_id)
            })
            .collect();

    // Edge pattern `i` connects node pattern `i` to node pattern `i + 1`.
    for (edge_pat, next_node_pat) in pattern.edges.iter().zip(&pattern.nodes[1..]) {
        let mut next_frontier = Vec::new();
        for (binding, current_id) in &frontier {
            let edges = match memtable.get_edges(*current_id) {
                Some(edges) => edges,
                None => continue,
            };
            for edge in edges {
                if let Some(label) = &edge_pat.label {
                    if &edge.label != label {
                        continue;
                    }
                }
                let target_id = edge.target_id;
                // Targets that `build_node` cannot materialise, or that fail the next
                // node pattern, end this branch of the search.
                let target_node = match build_node(target_id, memtable) {
                    Some(node) if matches_node_props(&node, next_node_pat) => node,
                    _ => continue,
                };
                let mut new_binding = binding.clone();
                if let Some(var) = &next_node_pat.var {
                    new_binding.insert(var.clone(), target_node);
                }
                next_frontier.push((new_binding, target_id));
            }
        }
        frontier = next_frontier;
    }

    frontier
        .into_iter()
        .map(|(binding, _)| binding)
        .filter(|binding| match &query.where_clause {
            Some(condition) => eval_condition(condition, binding),
            None => true,
        })
        // Project onto the return list. The bindings are owned here, so entries are moved
        // out rather than cloned.
        .map(|mut binding| {
            query
                .return_vars
                .iter()
                .filter_map(|var| binding.remove_entry(var))
                .collect::<HashMap<_, _>>()
        })
        .collect()
}
/// Returns every node in `memtable` that satisfies the inline properties of `node_pat`.
///
/// An `id` property pins the pattern to a single node, so it is resolved with one direct
/// lookup instead of a full scan. Node ids are unsigned, so an `id` literal that is
/// negative or not an integer can never match and yields no candidates.
fn find_candidates<T: VectorType>(node_pat: &NodePattern, memtable: &MemTable<T>) -> Vec<Node<T>> {
    if let Some(id_prop) = node_pat.props.iter().find(|p| p.key == "id") {
        let exact_id = match &id_prop.value {
            // Reject negative literals instead of letting `as` wrap them into huge ids.
            LitValue::Int(tid) if *tid >= 0 => Some(*tid as u64),
            _ => None,
        };
        return exact_id
            .and_then(|id| build_node(id, memtable))
            // The remaining properties of the pattern must hold as well.
            .filter(|node| matches_node_props(node, node_pat))
            .into_iter()
            .collect();
    }
    memtable
        .all_node_ids()
        .into_iter()
        .filter_map(|id| build_node(id, memtable))
        .filter(|node| matches_node_props(node, node_pat))
        .collect()
}
/// Materialises an owned [`Node`] for `id` by copying its data out of `memtable`.
///
/// Returns `None` when the memtable has no vector or no payload for `id`. A node with no
/// entry in the edge store gets an empty edge list.
fn build_node<T: VectorType>(id: u64, memtable: &MemTable<T>) -> Option<Node<T>> {
    let vector = memtable.get_vector(id)?.to_vec();
    let payload = memtable.get_payload(id)?.clone();
    let edges = match memtable.get_edges(id) {
        Some(stored) => stored.to_vec(),
        None => Vec::new(),
    };
    Some(Node {
        id,
        vector,
        payload,
        edges,
    })
}
/// Returns `true` when `node` satisfies every inline property in `pat`.
///
/// The `id` key is special: it is compared against the node's own (unsigned) id, never
/// against the payload. An `id` literal that is negative or not an integer therefore
/// never matches. Every other key is looked up in the node's payload (serde_json's
/// `Index` yields `Value::Null` for a missing key) and compared with [`lit_matches_json`].
fn matches_node_props<T: VectorType>(node: &Node<T>, pat: &NodePattern) -> bool {
    pat.props.iter().all(|prop| match prop.key.as_str() {
        "id" => match &prop.value {
            // Check the sign first so a negative literal cannot wrap onto a real id.
            LitValue::Int(target_id) => *target_id >= 0 && *target_id as u64 == node.id,
            // A non-integer literal can never equal an integer id.
            _ => false,
        },
        field => lit_matches_json(&prop.value, &node.payload[field]),
    })
}
fn lit_matches_json(lit: &LitValue, json: &serde_json::Value) -> bool {
match lit {
LitValue::Int(n) => json.as_i64() == Some(*n),
LitValue::Float(f) => json.as_f64() == Some(*f),
LitValue::Str(s) => json.as_str() == Some(s),
LitValue::Bool(b) => json.as_bool() == Some(*b),
}
}
/// Evaluates a `WHERE` condition against one set of variable bindings.
///
/// `And` / `Or` short-circuit from left to right. A comparison evaluates its left operand,
/// then its right operand, and hands both to [`compare_values`].
fn eval_condition<T: VectorType>(cond: &Condition, binding: &HashMap<String, Node<T>>) -> bool {
    match cond {
        Condition::And(lhs, rhs) => eval_condition(lhs, binding) && eval_condition(rhs, binding),
        Condition::Or(lhs, rhs) => eval_condition(lhs, binding) || eval_condition(rhs, binding),
        Condition::Compare { left, op, right } => {
            compare_values(&eval_expr(left, binding), op, &eval_expr(right, binding))
        }
    }
}
/// Evaluates an expression to a [`RuntimeValue`] under the given bindings.
///
/// A literal converts directly. A property access on an unbound variable yields `Null`.
/// The `id` field is the node's own id; any other field is read from the node's payload
/// and converted with [`json_to_runtime`].
fn eval_expr<T: VectorType>(expr: &Expr, binding: &HashMap<String, Node<T>>) -> RuntimeValue {
    match expr {
        Expr::Literal(lit) => lit_to_runtime(lit),
        Expr::Property { var, field } => match binding.get(var) {
            None => RuntimeValue::Null,
            Some(node) if field == "id" => RuntimeValue::Int(node.id as i64),
            Some(node) => json_to_runtime(&node.payload[field]),
        },
    }
}
/// A dynamically typed scalar produced while evaluating query expressions.
#[derive(Clone, Debug)]
enum RuntimeValue {
    /// Unbound variable, or a JSON value with no scalar counterpart (null, array, object).
    Null,
    /// Boolean scalar.
    Bool(bool),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float: float literals, or JSON numbers that do not fit in an `i64`.
    Float(f64),
    /// Owned string.
    Str(String),
}
/// Converts a JSON value into a [`RuntimeValue`].
///
/// Numbers become `Int` when they fit in an `i64` and `Float` otherwise. Strings and
/// booleans map to their counterparts. `null`, arrays and objects all become `Null`.
fn json_to_runtime(v: &serde_json::Value) -> RuntimeValue {
    if let Some(flag) = v.as_bool() {
        RuntimeValue::Bool(flag)
    } else if let Some(text) = v.as_str() {
        RuntimeValue::Str(text.to_owned())
    } else if let Some(int) = v.as_i64() {
        RuntimeValue::Int(int)
    } else if v.is_number() {
        RuntimeValue::Float(v.as_f64().unwrap_or(0.0))
    } else {
        RuntimeValue::Null
    }
}
/// Converts a parsed literal into the equivalent [`RuntimeValue`]; strings are copied.
fn lit_to_runtime(lit: &LitValue) -> RuntimeValue {
    match *lit {
        LitValue::Bool(value) => RuntimeValue::Bool(value),
        LitValue::Int(value) => RuntimeValue::Int(value),
        LitValue::Float(value) => RuntimeValue::Float(value),
        LitValue::Str(ref text) => RuntimeValue::Str(text.clone()),
    }
}
/// Evaluates `lhs op rhs` for two runtime values.
///
/// - Two integers compare exactly, and two strings compare by their `Ord` order.
/// - Booleans support only `=` and `!=`.
/// - Any other numeric pairing (at least one float) is widened to `f64` and compared via
///   [`cmp_f64`].
/// - Every remaining pairing, meaning mismatched kinds or anything involving `Null`,
///   evaluates to `false` whatever the operator.
fn compare_values(lhs: &RuntimeValue, op: &CompOp, rhs: &RuntimeValue) -> bool {
    let widen = |value: &RuntimeValue| -> Option<f64> {
        match value {
            RuntimeValue::Int(i) => Some(*i as f64),
            RuntimeValue::Float(f) => Some(*f),
            _ => None,
        }
    };
    match (lhs, rhs) {
        // Integer pairs stay in i64 so large values are not rounded through f64.
        (RuntimeValue::Int(x), RuntimeValue::Int(y)) => cmp_ord(x, op, y),
        (RuntimeValue::Str(x), RuntimeValue::Str(y)) => cmp_ord(x, op, y),
        (RuntimeValue::Bool(x), RuntimeValue::Bool(y)) => match op {
            CompOp::Eq => x == y,
            CompOp::Ne => x != y,
            // Ordering operators on booleans evaluate to false.
            CompOp::Gt | CompOp::Gte | CompOp::Lt | CompOp::Lte => false,
        },
        _ => match (widen(lhs), widen(rhs)) {
            (Some(x), Some(y)) => cmp_f64(x, op, y),
            _ => false,
        },
    }
}
/// Applies the comparison operator `op` to two totally ordered values.
fn cmp_ord<T: Ord>(a: &T, op: &CompOp, b: &T) -> bool {
    let ordering = a.cmp(b);
    match op {
        CompOp::Eq => ordering.is_eq(),
        CompOp::Ne => ordering.is_ne(),
        CompOp::Gt => ordering.is_gt(),
        CompOp::Gte => ordering.is_ge(),
        CompOp::Lt => ordering.is_lt(),
        CompOp::Lte => ordering.is_le(),
    }
}
/// Applies the comparison operator `op` to two floats.
///
/// `Eq` accepts values that are exactly equal or that differ by less than `f64::EPSILON`.
/// This is an absolute tolerance, not a relative one. `Ne` is defined as the exact negation
/// of `Eq`, so the two never disagree. The ordering operators use plain IEEE-754
/// comparison, so any ordering involving NaN is `false`.
fn cmp_f64(a: f64, op: &CompOp, b: f64) -> bool {
    // `a == b` is needed for equal infinities: their difference is NaN, so the tolerance
    // check alone would report them as unequal.
    let approx_eq = a == b || (a - b).abs() < f64::EPSILON;
    match op {
        CompOp::Eq => approx_eq,
        CompOp::Ne => !approx_eq,
        CompOp::Gt => a > b,
        CompOp::Gte => a >= b,
        CompOp::Lt => a < b,
        CompOp::Lte => a <= b,
    }
}