use super::tql_ast::*;
use crate::VectorType;
use crate::error::TriviumError;
use crate::filter::Filter;
use crate::node::{Node, NodeId};
use crate::storage::memtable::MemTable;
use std::collections::{BTreeMap, HashMap, HashSet};
/// Result set of a read query: one entry per result row, each row mapping a
/// variable name (or `"_"` when no variable name applies) to the node bound to it.
pub type TqlResult<T> = Vec<HashMap<String, Node<T>>>;
/// Summary of an applied mutation.
// NOTE(review): nothing in this part of the file constructs this type; it is
// presumably filled in by whoever applies the `MutationOp`s — confirm.
#[derive(Debug, Clone)]
pub struct TqlMutResult {
/// Number of affected items (presumably nodes/edges touched — confirm with the caller).
pub affected: usize,
/// Ids of nodes reported as created.
pub created_ids: Vec<NodeId>,
}
/// A single planned write produced by `execute_tql_mutation`.
///
/// The functions in this file only *plan* these operations from a read-only
/// `MemTable`; they do not apply them.
#[derive(Debug, Clone)]
pub enum MutationOp<T: VectorType> {
/// Insert a new node. `execute_create` fills `vector` with `T::default()`
/// repeated `dim` times and uses `"_anon"` as `var` for unnamed nodes.
InsertNode {
var: String,
vector: Vec<T>,
payload: serde_json::Value,
},
/// Create an edge from `src_id` to `dst_id`.
// NOTE(review): `execute_create` emits id 0 for an endpoint whose variable
// was not resolved by the MATCH source — confirm how the applier treats 0.
LinkEdge {
src_id: NodeId,
dst_id: NodeId,
label: String,
weight: f32,
},
/// Carry a complete new payload for node `id` (built by `execute_set` from
/// a clone of the current payload).
UpdatePayload {
id: NodeId,
payload: serde_json::Value,
},
/// Delete node `id`; `detach` is forwarded from the DELETE action.
DeleteNode { id: NodeId, detach: bool },
}
/// Maximum number of `tql_dfs` invocations a single MATCH may perform before
/// it fails with `TriviumError::QueryExecution`.
const MAX_BUDGET: usize = 100_000;
/// Row cap applied while collecting results when the query has no LIMIT.
const DEFAULT_ROW_LIMIT: usize = 5_000;
/// Executes a read-only TQL query against `memtable`.
///
/// Pipeline: entry clause (FIND / MATCH / OPTIONAL MATCH / SEARCH) → ORDER BY →
/// OFFSET → LIMIT → aggregation / DISTINCT → projection pruning. When
/// `query.explain` is set, a single-row plan description is returned instead
/// and nothing is executed.
///
/// # Errors
/// Propagates any error from the entry-clause executors (for example a MATCH
/// that exceeds its traversal budget).
pub fn execute_tql<T: VectorType>(
    query: &TqlQuery,
    memtable: &MemTable<T>,
) -> Result<TqlResult<T>, TriviumError> {
    if query.explain {
        return Ok(generate_explain_plan(query, memtable));
    }

    // Rows skipped by OFFSET still have to be produced first, so the collection
    // cap must cover `offset + limit`, not just `limit`.
    let window = query
        .limit
        .map(|lim| lim.saturating_add(query.offset.unwrap_or(0)));
    let row_limit = match window {
        // Without ORDER BY the first `window` rows are exactly the ones returned.
        Some(w) if query.order_by.is_empty() => w,
        // With ORDER BY, stopping after `window` unsorted rows would sort an
        // arbitrary subset; collect up to the default cap before sorting.
        Some(w) => w.max(DEFAULT_ROW_LIMIT),
        None => DEFAULT_ROW_LIMIT,
    };

    let mut results = match &query.entry {
        QueryEntry::Find { filter } => execute_find(filter, query, memtable, row_limit)?,
        QueryEntry::Match { pattern } => execute_match(pattern, query, memtable, row_limit, false)?,
        QueryEntry::OptionalMatch { pattern } => {
            execute_match(pattern, query, memtable, row_limit, true)?
        }
        QueryEntry::Search {
            vector,
            top_k,
            expand,
        } => execute_search(vector, *top_k, expand.as_ref(), query, memtable, row_limit)?,
    };

    if !query.order_by.is_empty() {
        sort_results(&mut results, &query.order_by, memtable);
    }

    if let Some(off) = query.offset {
        if off < results.len() {
            results.drain(..off);
        } else {
            results.clear();
        }
    }
    if let Some(lim) = query.limit {
        results.truncate(lim);
    }

    // NOTE(review): aggregation and DISTINCT run after LIMIT/OFFSET, so they only
    // see the already-truncated rows — confirm this is the intended TQL semantics.
    results = apply_aggregation_and_distinct(&query.returns, results, memtable)?;
    apply_projection_pruning(&query.returns, &mut results);
    Ok(results)
}
/// FIND entry point: scans every node id, keeps the nodes whose payload matches
/// `filter` (and the optional WHERE predicate), and emits one single-binding row
/// per node until `row_limit` rows have been collected.
///
/// The binding name is `"_"` for `RETURN *`, the first listed variable for a
/// variable list (no binding at all if that list is empty), or the first
/// variable mentioned by the return expressions (falling back to `"_"`).
fn execute_find<T: VectorType>(
    filter: &Filter,
    query: &TqlQuery,
    mt: &MemTable<T>,
    row_limit: usize,
) -> Result<TqlResult<T>, TriviumError> {
    let mut rows = Vec::new();

    for id in mt.all_node_ids() {
        if rows.len() >= row_limit {
            break;
        }

        // Nodes without a payload can never satisfy the filter.
        let Some(payload) = mt.get_payload(id) else {
            continue;
        };
        if !filter.matches(payload) {
            continue;
        }
        if let Some(pred) = &query.predicate
            && !eval_predicate_single(pred, id, mt)
        {
            continue;
        }
        let Some(node) = build_node(id, mt) else {
            continue;
        };

        let binding: Option<String> = match &query.returns {
            ReturnClause::All => Some("_".into()),
            ReturnClause::Variables(vars) => vars.first().cloned(),
            ReturnClause::Expressions(exprs) => {
                Some(extract_first_var_from_exprs(exprs).unwrap_or_else(|| "_".into()))
            }
        };

        let mut row = HashMap::new();
        if let Some(name) = binding {
            row.insert(name, node);
        }
        rows.push(row);
    }

    Ok(rows)
}
fn execute_match<T: VectorType>(
pattern: &TqlPattern,
query: &TqlQuery,
mt: &MemTable<T>,
row_limit: usize,
_optional: bool,
) -> Result<TqlResult<T>, TriviumError> {
let mut var_map: HashMap<String, usize> = HashMap::new();
for node_pat in &pattern.nodes {
if let Some(var) = &node_pat.var {
let next_idx = var_map.len();
var_map.entry(var.clone()).or_insert(next_idx);
}
}
let return_vars: Vec<String> = match &query.returns {
ReturnClause::All => var_map.keys().cloned().collect(),
ReturnClause::Variables(vars) => vars.clone(),
ReturnClause::Expressions(exprs) => extract_vars_from_exprs(exprs),
};
for var in &return_vars {
let next_idx = var_map.len();
var_map.entry(var.clone()).or_insert(next_idx);
}
let start_candidates =
find_tql_candidates_optimized(&pattern.nodes[0], pattern.edges.first(), mt);
let mut results = Vec::new();
let mut budget: usize = 0;
for start_id in start_candidates {
let mut env = vec![None; var_map.len()];
let cont = tql_dfs(
mt,
pattern,
query.predicate.as_ref(),
&return_vars,
&var_map,
0, &mut env,
start_id,
&mut results,
&mut budget,
MAX_BUDGET,
row_limit,
)?;
if !cont {
break;
}
}
Ok(results)
}
/// Depth-first pattern matcher: tries to bind `current` to the pattern node at
/// `layer_idx` and then extends the match along the pattern edge of that layer.
///
/// * `env` holds the node id currently bound to each variable slot (`var_map`
///   maps variable name → slot). The slot touched by this call is restored to
///   its previous value before returning `Ok(..)`.
/// * `results` receives one row per complete match that passes `predicate`.
/// * `budget` counts invocations across the whole query.
///
/// Returns `Ok(true)` to let the caller continue exploring, `Ok(false)` once
/// `results` has reached `row_limit`.
///
/// # Errors
/// `TriviumError::QueryExecution` when `budget` exceeds `max_budget`.
// NOTE(review): a variable that appears at two pattern positions shares one
// slot and is simply overwritten here — no check that both positions bind the
// same node. Confirm whether repeated variables are meant to be supported.
fn tql_dfs<T: VectorType>(
mt: &MemTable<T>,
pattern: &TqlPattern,
predicate: Option<&Predicate>,
return_vars: &[String],
var_map: &HashMap<String, usize>,
layer_idx: usize,
env: &mut Vec<Option<u64>>,
current: u64,
results: &mut TqlResult<T>,
budget: &mut usize,
max_budget: usize,
row_limit: usize,
) -> Result<bool, TriviumError> {
// Every visit costs one budget unit, including visits rejected by the filter below.
*budget += 1;
if *budget > max_budget {
return Err(TriviumError::QueryExecution(format!(
"Query exceeded budget of {} steps",
max_budget
)));
}
let node_pat = &pattern.nodes[layer_idx];
// Inline node filter: a mismatch prunes this branch but does not stop the search.
if let Some(filter) = &node_pat.filter
&& !matches_filter_with_id(filter, current, mt)
{
return Ok(true); }
// Bind the variable (if any) and remember the previous slot value for backtracking.
let old_val = if let Some(var) = &node_pat.var {
let idx = var_map[var];
let old = env[idx];
env[idx] = Some(current);
Some((idx, old))
} else {
None
};
// The last pattern node has index `edges.len()`: a full match has been built.
if layer_idx == pattern.edges.len() {
let passed = match predicate {
Some(pred) => eval_predicate_env(pred, env, var_map, mt),
None => true,
};
if passed {
let mut row = HashMap::new();
// Only bound variables whose node can still be materialised end up in the row.
for var in return_vars {
if let Some(&idx) = var_map.get(var)
&& let Some(id) = env[idx]
&& let Some(node) = build_node(id, mt)
{
row.insert(var.clone(), node);
}
}
results.push(row);
if results.len() >= row_limit {
if let Some((idx, old)) = old_val {
env[idx] = old;
}
return Ok(false);
}
}
} else {
let edge_pat = &pattern.edges[layer_idx];
if let Some(hop) = &edge_pat.hop_range {
// Variable-length edge: delegate to the bounded-depth walker, seeding
// its visited set with the current node.
let mut visited = HashSet::new();
visited.insert(current);
let cont = tql_dfs_variable_length(
mt,
pattern,
predicate,
return_vars,
var_map,
layer_idx,
env,
current,
&edge_pat.labels,
hop.min,
hop.max,
0,
&mut visited,
results,
budget,
max_budget,
row_limit,
edge_pat.direction,
)?;
if !cont {
if let Some((idx, old)) = old_val {
env[idx] = old;
}
return Ok(false);
}
} else {
// Single-hop edge: try every neighbour as the next pattern node.
let neighbors = collect_neighbors(mt, current, &edge_pat.labels, edge_pat.direction);
for next_id in neighbors {
let cont = tql_dfs(
mt,
pattern,
predicate,
return_vars,
var_map,
layer_idx + 1,
env,
next_id,
results,
budget,
max_budget,
row_limit,
)?;
if !cont {
if let Some((idx, old)) = old_val {
env[idx] = old;
}
return Ok(false);
}
}
}
}
// Backtrack: undo this layer's binding before returning to the caller.
if let Some((idx, old)) = old_val {
env[idx] = old;
}
Ok(true)
}
/// Walks a variable-length edge (`*min..max`) starting at `current`.
///
/// Whenever the walked depth is at least `min_depth`, `current` is offered to
/// `tql_dfs` as a candidate for the next pattern node (`layer_idx + 1`). While
/// the depth is below `max_depth` the walk continues through neighbours that
/// are not already on the current path (`visited`); a neighbour is removed from
/// `visited` again once its subtree has been explored.
///
/// Returns `Ok(false)` as soon as any nested call reports that the row limit
/// was reached, `Ok(true)` otherwise.
///
/// # Errors
/// Propagates errors from `tql_dfs` (budget exhaustion).
fn tql_dfs_variable_length<T: VectorType>(
    mt: &MemTable<T>,
    pattern: &TqlPattern,
    predicate: Option<&Predicate>,
    return_vars: &[String],
    var_map: &HashMap<String, usize>,
    layer_idx: usize,
    env: &mut Vec<Option<u64>>,
    current: u64,
    labels: &[String],
    min_depth: usize,
    max_depth: usize,
    current_depth: usize,
    visited: &mut HashSet<u64>,
    results: &mut TqlResult<T>,
    budget: &mut usize,
    max_budget: usize,
    row_limit: usize,
    direction: EdgeDirection,
) -> Result<bool, TriviumError> {
    // Deep enough: treat `current` as the end of this edge.
    if current_depth >= min_depth
        && !tql_dfs(
            mt,
            pattern,
            predicate,
            return_vars,
            var_map,
            layer_idx + 1,
            env,
            current,
            results,
            budget,
            max_budget,
            row_limit,
        )?
    {
        return Ok(false);
    }

    // Hop budget used up: nothing further to expand from here.
    if current_depth >= max_depth {
        return Ok(true);
    }

    for next in collect_neighbors(mt, current, labels, direction) {
        // `insert` returns false when the node is already on the current path.
        if !visited.insert(next) {
            continue;
        }
        let keep_going = tql_dfs_variable_length(
            mt,
            pattern,
            predicate,
            return_vars,
            var_map,
            layer_idx,
            env,
            next,
            labels,
            min_depth,
            max_depth,
            current_depth + 1,
            visited,
            results,
            budget,
            max_budget,
            row_limit,
            direction,
        )?;
        visited.remove(&next);
        if !keep_going {
            return Ok(false);
        }
    }
    Ok(true)
}
/// Lists the nodes adjacent to `current` in the requested `direction`.
///
/// An empty `labels` slice accepts every edge; otherwise only edges whose label
/// is in `labels` count. Forward neighbours are pushed once per matching
/// outgoing edge. With labels given, a backward neighbour is pushed at most once
/// per entry of the incoming-source list.
fn collect_neighbors<T: VectorType>(
    mt: &MemTable<T>,
    current: u64,
    labels: &[String],
    direction: EdgeDirection,
) -> Vec<u64> {
    let follow_outgoing =
        direction == EdgeDirection::Forward || direction == EdgeDirection::Both;
    let follow_incoming =
        direction == EdgeDirection::Backward || direction == EdgeDirection::Both;

    let mut found = Vec::new();

    if follow_outgoing && let Some(edges) = mt.get_edges(current) {
        found.extend(
            edges
                .iter()
                .filter(|edge| labels.is_empty() || labels.contains(&edge.label))
                .map(|edge| edge.target_id),
        );
    }

    if follow_incoming {
        for &src_id in mt.get_incoming_sources(current) {
            // With a label filter, confirm that the source really has a
            // matching edge pointing at `current`.
            let linked = labels.is_empty()
                || mt.get_edges(src_id).is_some_and(|edges| {
                    edges
                        .iter()
                        .any(|edge| edge.target_id == current && labels.contains(&edge.label))
                });
            if linked {
                found.push(src_id);
            }
        }
    }

    found
}
/// SEARCH entry point: scores every stored vector of matching dimension against
/// the query vector, keeps the `top_k` best, optionally expands them through the
/// graph, and returns the surviving nodes bound to `"_"`.
///
/// Result order is deterministic: seeds come first in descending score order,
/// followed by expanded neighbours grouped by the seed that reached them first
/// and sorted by id within each group. The WHERE predicate (if any) is applied
/// per node, and at most `row_limit` rows are returned.
fn execute_search<T: VectorType>(
    vector: &[f64],
    top_k: usize,
    expand: Option<&ExpandClause>,
    query: &TqlQuery,
    mt: &MemTable<T>,
    row_limit: usize,
) -> Result<TqlResult<T>, TriviumError> {
    let query_vec: Vec<T> = vector.iter().map(|v| T::from_f32(*v as f32)).collect();
    let dim = query_vec.len();

    let mut scored: Vec<(NodeId, f32)> = mt
        .all_node_ids()
        .iter()
        .filter_map(|&id| {
            let vec = mt.get_vector(id)?;
            // Vectors of a different dimension cannot be compared.
            if vec.len() != dim {
                return None;
            }
            Some((id, T::similarity(&query_vec, vec)))
        })
        .collect();

    // `sort_by` needs a total order; a NaN score is ranked below every real
    // score instead of comparing "equal" to everything.
    let rank_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    scored.sort_by(|a, b| rank_key(b.1).total_cmp(&rank_key(a.1)));
    scored.truncate(top_k);

    let mut candidates: Vec<NodeId> = scored.iter().map(|s| s.0).collect();

    if let Some(ex) = expand {
        // A label filter is only forwarded when exactly one label is given.
        let label = if ex.labels.len() == 1 {
            Some(ex.labels[0].as_str())
        } else {
            None
        };

        // Seeds always stay in the result, in ranked order; neighbours are
        // appended behind them so ranking survives the row limit below.
        let mut seen: HashSet<NodeId> = candidates.iter().copied().collect();
        let seed_count = candidates.len();
        for seed_pos in 0..seed_count {
            let seed = candidates[seed_pos];
            let neighbors =
                crate::graph::pathfinding::k_hop_neighbors(mt, seed, ex.max_depth, label);

            let mut reached: Vec<NodeId> = Vec::new();
            for (&nid, &dist) in &neighbors {
                if dist >= ex.min_depth {
                    reached.push(nid);
                }
            }
            // Map iteration order is unspecified; sort for reproducible output.
            reached.sort_unstable();

            for nid in reached {
                if seen.insert(nid) {
                    candidates.push(nid);
                }
            }
        }
    }

    let mut results = Vec::new();
    for id in candidates {
        if results.len() >= row_limit {
            break;
        }
        if let Some(pred) = &query.predicate
            && !eval_predicate_single(pred, id, mt)
        {
            continue;
        }
        if let Some(node) = build_node(id, mt) {
            let mut row = HashMap::new();
            row.insert("_".into(), node);
            results.push(row);
        }
    }
    Ok(results)
}
/// Evaluates a WHERE predicate against a multi-variable binding environment.
///
/// `env[slot]` holds the node id bound to the variable whose slot is given by
/// `var_map`. A document filter without an explicit variable is applied to the
/// first bound slot. Unbound variables and missing payloads make a document
/// filter evaluate to `false`.
fn eval_predicate_env<T: VectorType>(
    pred: &Predicate,
    env: &[Option<u64>],
    var_map: &HashMap<String, usize>,
    mt: &MemTable<T>,
) -> bool {
    match pred {
        Predicate::Not(inner) => !eval_predicate_env(inner, env, var_map, mt),
        Predicate::And(lhs, rhs) => {
            eval_predicate_env(lhs, env, var_map, mt) && eval_predicate_env(rhs, env, var_map, mt)
        }
        Predicate::Or(lhs, rhs) => {
            eval_predicate_env(lhs, env, var_map, mt) || eval_predicate_env(rhs, env, var_map, mt)
        }
        Predicate::Compare { left, op, right } => {
            let left_value = eval_tql_expr(left, env, var_map, mt);
            let right_value = eval_tql_expr(right, env, var_map, mt);
            compare_runtime(&left_value, op, &right_value)
        }
        Predicate::DocFilter { var, filter } => {
            let target = match var {
                Some(name) => var_map.get(name).and_then(|&slot| env[slot]),
                // No variable given: use the first slot that is bound.
                None => env.iter().copied().flatten().next(),
            };
            target
                .and_then(|nid| mt.get_payload(nid))
                .is_some_and(|payload| filter.matches(payload))
        }
    }
}
/// Evaluates a WHERE predicate against a single node `id`.
///
/// Variable names inside the predicate are ignored: every property reference
/// and every document filter is resolved against `id`.
fn eval_predicate_single<T: VectorType>(pred: &Predicate, id: NodeId, mt: &MemTable<T>) -> bool {
    match pred {
        Predicate::Not(inner) => !eval_predicate_single(inner, id, mt),
        Predicate::And(lhs, rhs) => {
            eval_predicate_single(lhs, id, mt) && eval_predicate_single(rhs, id, mt)
        }
        Predicate::Or(lhs, rhs) => {
            eval_predicate_single(lhs, id, mt) || eval_predicate_single(rhs, id, mt)
        }
        Predicate::DocFilter { filter, .. } => mt
            .get_payload(id)
            .is_some_and(|payload| filter.matches(payload)),
        Predicate::Compare { left, op, right } => {
            let left_value = eval_tql_expr_single(left, id, mt);
            let right_value = eval_tql_expr_single(right, id, mt);
            compare_runtime(&left_value, op, &right_value)
        }
    }
}
/// Scalar value produced while evaluating expressions for WHERE comparisons and
/// ORDER BY keys.
#[derive(Debug, Clone)]
enum RuntimeValue {
/// Integer value (also used for node ids).
Int(i64),
/// Floating-point value.
Float(f64),
/// String value.
Str(String),
/// Boolean value.
Bool(bool),
/// Missing, unbound, or non-scalar (JSON array/object/null) value.
Null,
}
/// Evaluates an expression in a multi-variable binding environment.
///
/// `var.id` yields the bound node id; any other `var.field` is read from the
/// node payload. An unknown or unbound variable, or a node without payload,
/// yields `RuntimeValue::Null`.
fn eval_tql_expr<T: VectorType>(
    expr: &TqlExpr,
    env: &[Option<u64>],
    var_map: &HashMap<String, usize>,
    mt: &MemTable<T>,
) -> RuntimeValue {
    match expr {
        TqlExpr::Literal(lit) => lit_to_runtime(lit),
        TqlExpr::Property { var, field } => {
            let Some(id) = var_map.get(var).and_then(|&slot| env[slot]) else {
                return RuntimeValue::Null;
            };
            if field == "id" {
                return RuntimeValue::Int(id as i64);
            }
            match mt.get_payload(id) {
                Some(payload) => json_to_runtime(&payload[field]),
                None => RuntimeValue::Null,
            }
        }
    }
}
/// Evaluates an expression against a single node `id`; the variable name of a
/// property reference is ignored.
///
/// The pseudo-field `id` yields the node id itself; other fields come from the
/// payload, and a node without payload yields `RuntimeValue::Null`.
fn eval_tql_expr_single<T: VectorType>(
    expr: &TqlExpr,
    id: NodeId,
    mt: &MemTable<T>,
) -> RuntimeValue {
    match expr {
        TqlExpr::Literal(lit) => lit_to_runtime(lit),
        TqlExpr::Property { field, .. } if field == "id" => RuntimeValue::Int(id as i64),
        TqlExpr::Property { field, .. } => match mt.get_payload(id) {
            Some(payload) => json_to_runtime(&payload[field]),
            None => RuntimeValue::Null,
        },
    }
}
fn json_to_runtime(v: &serde_json::Value) -> RuntimeValue {
match v {
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
RuntimeValue::Int(i)
} else {
RuntimeValue::Float(n.as_f64().unwrap_or(0.0))
}
}
serde_json::Value::String(s) => RuntimeValue::Str(s.clone()),
serde_json::Value::Bool(b) => RuntimeValue::Bool(*b),
_ => RuntimeValue::Null,
}
}
/// Converts a parsed query literal into the matching `RuntimeValue` variant.
fn lit_to_runtime(lit: &TqlLiteral) -> RuntimeValue {
    match lit {
        TqlLiteral::Null => RuntimeValue::Null,
        TqlLiteral::Bool(flag) => RuntimeValue::Bool(*flag),
        TqlLiteral::Int(int) => RuntimeValue::Int(*int),
        TqlLiteral::Float(float) => RuntimeValue::Float(*float),
        TqlLiteral::Str(text) => RuntimeValue::Str(text.clone()),
    }
}
/// Applies comparison operator `op` to two runtime values.
///
/// Integers and strings compare exactly; a mixed Int/Float pair is compared as
/// `f64`. Booleans only support `Eq` / `Ne`. Every other combination —
/// including anything involving `Null` — evaluates to `false` for all operators.
fn compare_runtime(lhs: &RuntimeValue, op: &TqlCompOp, rhs: &RuntimeValue) -> bool {
    match (lhs, rhs) {
        // Same-type, exactly ordered values.
        (RuntimeValue::Int(l), RuntimeValue::Int(r)) => cmp_ord(l, op, r),
        (RuntimeValue::Str(l), RuntimeValue::Str(r)) => cmp_ord(l, op, r),
        // Anything involving a float is compared in f64 space.
        (RuntimeValue::Float(l), RuntimeValue::Float(r)) => cmp_f64(*l, op, *r),
        (RuntimeValue::Int(l), RuntimeValue::Float(r)) => cmp_f64(*l as f64, op, *r),
        (RuntimeValue::Float(l), RuntimeValue::Int(r)) => cmp_f64(*l, op, *r as f64),
        // Booleans have equality but no ordering.
        (RuntimeValue::Bool(l), RuntimeValue::Bool(r)) => match op {
            TqlCompOp::Eq => l == r,
            TqlCompOp::Ne => l != r,
            _ => false,
        },
        _ => false,
    }
}
/// Evaluates `a <op> b` for totally ordered values (used for integers and strings).
fn cmp_ord<T: Ord>(a: &T, op: &TqlCompOp, b: &T) -> bool {
    use std::cmp::Ordering::{Equal, Greater, Less};

    let ordering = a.cmp(b);
    match op {
        TqlCompOp::Eq => ordering == Equal,
        TqlCompOp::Ne => ordering != Equal,
        TqlCompOp::Gt => ordering == Greater,
        TqlCompOp::Gte => ordering != Less,
        TqlCompOp::Lt => ordering == Less,
        TqlCompOp::Lte => ordering != Greater,
    }
}
/// Evaluates `a <op> b` for floating-point values.
///
/// `Eq` / `Ne` tolerate an absolute difference below `f64::EPSILON`; the
/// ordering operators use plain IEEE comparison. NaN is neither equal nor
/// unequal to anything under this function.
fn cmp_f64(a: f64, op: &TqlCompOp, b: f64) -> bool {
    match op {
        // `a == b` first: for two equal infinities `a - b` is NaN, so the
        // epsilon test alone would wrongly report them as not equal.
        TqlCompOp::Eq => a == b || (a - b).abs() < f64::EPSILON,
        TqlCompOp::Ne => (a - b).abs() >= f64::EPSILON,
        TqlCompOp::Gt => a > b,
        TqlCompOp::Gte => a >= b,
        TqlCompOp::Lt => a < b,
        TqlCompOp::Lte => a <= b,
    }
}
/// Returns the node ids that can start a match for `node_pat`.
///
/// Lookup strategies, in order:
/// 1. the filter pins an `id` → that single node, if it exists and passes the
///    whole filter;
/// 2. a property index answers an equality in the filter → the indexed ids that
///    pass the whole filter;
/// 3. otherwise → scan of all node ids, filtered when a filter is present.
fn find_tql_candidates<T: VectorType>(node_pat: &TqlNodePattern, mt: &MemTable<T>) -> Vec<NodeId> {
    let Some(filter) = &node_pat.filter else {
        // No inline filter: every node is a candidate.
        return mt.all_node_ids().into_iter().collect();
    };

    if let Some(target_id) = extract_id_from_filter(filter) {
        let hit = mt.contains(target_id) && matches_filter_with_id(filter, target_id, mt);
        return if hit { vec![target_id] } else { Vec::new() };
    }

    if let Some(indexed_ids) = try_property_index_lookup(filter, mt) {
        // The index may cover only part of the filter, so re-check each hit.
        return indexed_ids
            .into_iter()
            .filter(|&id| matches_filter_with_id(filter, id, mt))
            .collect();
    }

    mt.all_node_ids()
        .into_iter()
        .filter(|&id| matches_filter_with_id(filter, id, mt))
        .collect()
}
/// Tries to answer part of `filter` from a property index.
///
/// A top-level equality on a non-`id` key, or the first such equality directly
/// inside a top-level `And` for which the index has an answer, is looked up.
/// Returns `None` when no index applies.
fn try_property_index_lookup<T: VectorType>(
    filter: &Filter,
    mt: &MemTable<T>,
) -> Option<Vec<NodeId>> {
    // Index probe for one filter term; only non-`id` equalities qualify.
    let probe = |term: &Filter| -> Option<Vec<NodeId>> {
        match term {
            Filter::Eq(key, val) if key != "id" => {
                mt.find_by_property_index(key, val).map(|ids| ids.to_vec())
            }
            _ => None,
        }
    };

    match filter {
        Filter::And(filters) => filters.iter().find_map(probe),
        other => probe(other),
    }
}
fn find_tql_candidates_optimized<T: VectorType>(
node_pat: &TqlNodePattern,
first_edge: Option<&TqlEdgePattern>,
mt: &MemTable<T>,
) -> Vec<NodeId> {
if node_pat.filter.is_some() {
return find_tql_candidates(node_pat, mt);
}
if let Some(edge_pat) = first_edge
&& !edge_pat.labels.is_empty()
{
let mut src_set: HashSet<NodeId> = HashSet::new();
for label in &edge_pat.labels {
for &(src, _dst) in mt.get_edges_by_label(label) {
src_set.insert(src);
}
}
let mut candidates: Vec<NodeId> = src_set.into_iter().collect();
candidates.sort_unstable(); return candidates;
}
find_tql_candidates(node_pat, mt)
}
/// Finds an `id == <integer>` equality in `filter`, looking at the filter itself
/// and, recursively, inside `And` conjunctions. Returns the first id found.
fn extract_id_from_filter(filter: &Filter) -> Option<NodeId> {
    match filter {
        Filter::Eq(key, val) if key == "id" => val.as_i64().map(|n| n as NodeId),
        Filter::And(filters) => filters.iter().find_map(|f| extract_id_from_filter(f)),
        _ => None,
    }
}
/// Checks `filter` against node `node_id`, treating `id` as a virtual field.
///
/// `id == n` compares against the node id itself; `And` / `Or` recurse so nested
/// `id` terms are handled the same way. Every other filter is evaluated on the
/// node payload and fails when the node has none.
fn matches_filter_with_id<T: VectorType>(
    filter: &Filter,
    node_id: NodeId,
    mt: &MemTable<T>,
) -> bool {
    match filter {
        Filter::And(parts) => parts
            .iter()
            .all(|part| matches_filter_with_id(part, node_id, mt)),
        Filter::Or(parts) => parts
            .iter()
            .any(|part| matches_filter_with_id(part, node_id, mt)),
        Filter::Eq(key, val) if key == "id" => match val.as_i64() {
            Some(target) => node_id == target as u64,
            None => false,
        },
        other => mt
            .get_payload(node_id)
            .is_some_and(|payload| other.matches(payload)),
    }
}
/// Sorts result rows in place by the ORDER BY terms, in priority order.
///
/// The first term on which two rows differ decides; rows equal on every term
/// keep their relative order (the sort is stable).
fn sort_results<T: VectorType>(
    results: &mut TqlResult<T>,
    order_by: &[OrderExpr],
    _mt: &MemTable<T>,
) {
    use std::cmp::Ordering;

    results.sort_by(|row_a, row_b| {
        order_by
            .iter()
            .map(|term| {
                let key_a = extract_order_key(&term.expr, row_a);
                let key_b = extract_order_key(&term.expr, row_b);
                let ordering = compare_for_sort(&key_a, &key_b);
                if term.descending {
                    ordering.reverse()
                } else {
                    ordering
                }
            })
            .find(|ordering| *ordering != Ordering::Equal)
            .unwrap_or(Ordering::Equal)
    });
}
/// Computes the sort key of `row` for one ORDER BY expression.
///
/// A property is read from the node bound to its variable, or from the
/// anonymous `"_"` binding if that variable is not in the row. `id` yields the
/// node id; a row with neither binding yields `Null`.
fn extract_order_key<T>(expr: &TqlExpr, row: &HashMap<String, Node<T>>) -> RuntimeValue {
    match expr {
        TqlExpr::Literal(lit) => lit_to_runtime(lit),
        TqlExpr::Property { var, field } => match row.get(var).or_else(|| row.get("_")) {
            Some(node) if field == "id" => RuntimeValue::Int(node.id as i64),
            Some(node) => json_to_runtime(&node.payload[field]),
            None => RuntimeValue::Null,
        },
    }
}
/// Ordering used by ORDER BY.
///
/// Values of different kinds are ordered by kind — numbers, then strings, then
/// booleans, then `Null` (nulls last). Within a kind: integers and strings use
/// their natural order, a mixed Int/Float pair is compared as `f64`,
/// `false < true`, and NaN sorts after every other number.
///
/// Incomparable pairs used to be reported as `Equal`, which made the relation
/// non-transitive. `sort_by` requires a total order and may panic or produce an
/// unspecified order without one.
fn compare_for_sort(a: &RuntimeValue, b: &RuntimeValue) -> std::cmp::Ordering {
    // Sort position of each kind of value; `Null` is last.
    fn kind_rank(value: &RuntimeValue) -> u8 {
        match value {
            RuntimeValue::Int(_) | RuntimeValue::Float(_) => 0,
            RuntimeValue::Str(_) => 1,
            RuntimeValue::Bool(_) => 2,
            RuntimeValue::Null => 3,
        }
    }

    // IEEE comparison, with NaN placed after all other numbers.
    fn cmp_floats(x: f64, y: f64) -> std::cmp::Ordering {
        x.partial_cmp(&y)
            .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
    }

    match (a, b) {
        (RuntimeValue::Int(x), RuntimeValue::Int(y)) => x.cmp(y),
        (RuntimeValue::Float(x), RuntimeValue::Float(y)) => cmp_floats(*x, *y),
        (RuntimeValue::Int(x), RuntimeValue::Float(y)) => cmp_floats(*x as f64, *y),
        (RuntimeValue::Float(x), RuntimeValue::Int(y)) => cmp_floats(*x, *y as f64),
        (RuntimeValue::Str(x), RuntimeValue::Str(y)) => x.cmp(y),
        (RuntimeValue::Bool(x), RuntimeValue::Bool(y)) => x.cmp(y),
        // Different kinds (or Null vs Null): order by kind only.
        _ => kind_rank(a).cmp(&kind_rank(b)),
    }
}
/// Materialises an owned `Node` for `id` by copying its vector, payload and
/// outgoing edges out of the memtable.
///
/// Returns `None` when the node has no vector or no payload; a node without an
/// edge list gets an empty one.
fn build_node<T: VectorType>(id: NodeId, mt: &MemTable<T>) -> Option<Node<T>> {
    let vector = mt.get_vector(id)?.to_vec();
    let payload = mt.get_payload(id)?.clone();
    let edges = match mt.get_edges(id) {
        Some(list) => list.to_vec(),
        None => Vec::new(),
    };
    Some(Node {
        id,
        vector,
        payload,
        edges,
    })
}
/// Collects the variable names referenced by the return expressions, in order
/// of first appearance and without repeats.
fn extract_vars_from_exprs(exprs: &[ReturnExpr]) -> Vec<String> {
    let mut names = Vec::new();
    exprs
        .iter()
        .for_each(|expr| collect_vars_from_kind(&expr.kind, &mut names));
    names.dedup();
    names
}
/// Returns the variable referenced by the first return expression that
/// mentions one, or `None` if no expression does.
fn extract_first_var_from_exprs(exprs: &[ReturnExpr]) -> Option<String> {
    exprs
        .iter()
        .find_map(|expr| first_var_from_kind(&expr.kind))
}
/// Appends the variable referenced by `kind` to `out` unless it is already
/// listed; aggregates are unwrapped to the expression they aggregate over.
fn collect_vars_from_kind(kind: &ReturnExprKind, out: &mut Vec<String>) {
    match kind {
        ReturnExprKind::Aggregate(_, inner) => collect_vars_from_kind(inner, out),
        ReturnExprKind::Var(name) | ReturnExprKind::Property(name, _) => {
            if !out.contains(name) {
                out.push(name.clone());
            }
        }
    }
}
/// Returns the variable a return expression refers to, looking through any
/// aggregate wrappers.
fn first_var_from_kind(kind: &ReturnExprKind) -> Option<String> {
    match kind {
        ReturnExprKind::Aggregate(_, inner) => first_var_from_kind(inner),
        ReturnExprKind::Var(name) | ReturnExprKind::Property(name, _) => Some(name.clone()),
    }
}
/// Post-processes rows for expression-style RETURN clauses.
///
/// If any expression is an aggregate, rows are grouped and aggregated; else if
/// any expression is marked DISTINCT, duplicate rows are dropped; otherwise
/// (and for non-expression RETURN clauses) rows pass through unchanged.
fn apply_aggregation_and_distinct<T: VectorType>(
    returns: &ReturnClause,
    results: TqlResult<T>,
    _mt: &MemTable<T>,
) -> Result<TqlResult<T>, TriviumError> {
    let ReturnClause::Expressions(exprs) = returns else {
        return Ok(results);
    };

    if exprs.iter().any(|e| is_aggregate(&e.kind)) {
        return Ok(apply_aggregation(results, exprs));
    }
    if exprs.iter().any(|e| e.distinct) {
        return Ok(apply_distinct(results, exprs));
    }
    Ok(results)
}
/// Tells whether a return expression is an aggregate function call.
fn is_aggregate(kind: &ReturnExprKind) -> bool {
    matches!(kind, ReturnExprKind::Aggregate(..))
}
fn apply_distinct<T: VectorType>(results: TqlResult<T>, _exprs: &[ReturnExpr]) -> TqlResult<T> {
let mut seen: HashSet<String> = HashSet::new();
let mut out = Vec::new();
for row in results {
let sig = row_signature(&row);
if seen.insert(sig) {
out.push(row);
}
}
out
}
/// Builds an order-independent identity string for a row from its
/// `variable:node_id` pairs, sorted and joined with `|`.
fn row_signature<T: VectorType>(row: &HashMap<String, Node<T>>) -> String {
    let mut pairs = Vec::with_capacity(row.len());
    for (var, node) in row {
        pairs.push(format!("{}:{}", var, node.id));
    }
    pairs.sort_unstable();
    pairs.join("|")
}
/// Groups rows by the non-aggregate return expressions and computes every
/// aggregate per group.
///
/// Each output row carries the grouping variables' nodes (taken from the first
/// row of the group) plus one synthetic node per aggregate: id 0, no vector or
/// edges, and a payload `{ alias: value }`. The alias defaults to the lowercased
/// debug name of the aggregate function. Groups are emitted in ascending
/// group-key order. An empty input produces an empty output.
fn apply_aggregation<T: VectorType>(results: TqlResult<T>, exprs: &[ReturnExpr]) -> TqlResult<T> {
    if results.is_empty() {
        return Vec::new();
    }

    let (agg_exprs, group_exprs): (Vec<&ReturnExpr>, Vec<&ReturnExpr>) =
        exprs.iter().partition(|e| is_aggregate(&e.kind));

    // BTreeMap keeps the groups in key order.
    let mut groups: BTreeMap<String, Vec<&HashMap<String, Node<T>>>> = BTreeMap::new();
    for row in &results {
        groups
            .entry(make_group_key(row, &group_exprs))
            .or_default()
            .push(row);
    }

    groups
        .values()
        .map(|rows| {
            let mut out_row: HashMap<String, Node<T>> = HashMap::new();

            // Grouping columns: copy the bound nodes from the group's first row.
            if let Some(first_row) = rows.first() {
                for expr in &group_exprs {
                    let Some(var) = first_var_from_kind(&expr.kind) else {
                        continue;
                    };
                    if let Some(node) = first_row.get(&var) {
                        out_row.insert(var, node.clone());
                    }
                }
            }

            // Aggregate columns: wrap each value in a synthetic node.
            for agg_expr in &agg_exprs {
                if let ReturnExprKind::Aggregate(func, inner) = &agg_expr.kind {
                    let alias = match &agg_expr.alias {
                        Some(name) => name.clone(),
                        None => format!("{:?}", func).to_lowercase(),
                    };
                    let agg_val = compute_aggregate(*func, inner, rows);
                    let agg_node = Node {
                        id: 0,
                        vector: Vec::new(),
                        payload: serde_json::json!({ &alias: agg_val }),
                        edges: Vec::new(),
                    };
                    out_row.insert(alias, agg_node);
                }
            }

            out_row
        })
        .collect()
}
/// Builds the grouping key of `row` from the non-aggregate return expressions.
///
/// A variable contributes `var:node_id`, a property contributes
/// `var.field=<json value>` (`null` when the field is absent); parts are joined
/// with `|`. Expressions whose variable is not bound in the row, and aggregate
/// expressions, contribute nothing.
fn make_group_key<T: VectorType>(
    row: &HashMap<String, Node<T>>,
    group_exprs: &[&ReturnExpr],
) -> String {
    group_exprs
        .iter()
        .filter_map(|expr| match &expr.kind {
            ReturnExprKind::Var(v) => row.get(v).map(|node| format!("{}:{}", v, node.id)),
            ReturnExprKind::Property(v, field) => row.get(v).map(|node| {
                let val = node
                    .payload
                    .get(field)
                    .cloned()
                    .unwrap_or(serde_json::Value::Null);
                format!("{}.{}={}", v, field, val)
            }),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("|")
}
/// Computes one aggregate function over the rows of a group.
///
/// * `Count` counts rows in which the aggregated variable is bound.
/// * `Sum` / `Avg` / `Min` / `Max` use the numeric values from
///   `resolve_inner_numeric`; `Avg`, `Min` and `Max` yield JSON `null` when
///   there are none (for `Min`/`Max`, also when the result is infinite).
/// * `Collect` gathers the JSON values from `resolve_inner_json` into an array.
fn compute_aggregate<T: VectorType>(
    func: AggFunc,
    inner: &ReturnExprKind,
    rows: &[&HashMap<String, Node<T>>],
) -> serde_json::Value {
    // Fresh iterator over the numeric inputs of this group.
    let numbers = || rows.iter().filter_map(|r| resolve_inner_numeric(inner, r));

    match func {
        AggFunc::Count => {
            let count = rows
                .iter()
                .filter(|r| resolve_inner(inner, r).is_some())
                .count();
            serde_json::json!(count)
        }
        AggFunc::Sum => {
            let sum: f64 = numbers().sum();
            serde_json::json!(sum)
        }
        AggFunc::Avg => {
            let vals: Vec<f64> = numbers().collect();
            if vals.is_empty() {
                return serde_json::Value::Null;
            }
            serde_json::json!(vals.iter().sum::<f64>() / vals.len() as f64)
        }
        AggFunc::Min => {
            // The infinite seed doubles as the "no values" marker.
            let min = numbers().fold(f64::INFINITY, f64::min);
            if min.is_infinite() {
                serde_json::Value::Null
            } else {
                serde_json::json!(min)
            }
        }
        AggFunc::Max => {
            let max = numbers().fold(f64::NEG_INFINITY, f64::max);
            if max.is_infinite() {
                serde_json::Value::Null
            } else {
                serde_json::json!(max)
            }
        }
        AggFunc::Collect => {
            let vals: Vec<serde_json::Value> = rows
                .iter()
                .filter_map(|r| resolve_inner_json(inner, r))
                .collect();
            serde_json::json!(vals)
        }
    }
}
/// Looks up the node that an aggregate's inner expression refers to in `row`.
/// Nested aggregates resolve to `None`.
fn resolve_inner<'a, T: VectorType>(
    inner: &ReturnExprKind,
    row: &'a HashMap<String, Node<T>>,
) -> Option<&'a Node<T>> {
    match inner {
        ReturnExprKind::Var(name) | ReturnExprKind::Property(name, _) => row.get(name),
        _ => None,
    }
}
/// Numeric input of an aggregate for one row.
///
/// A bare variable counts as `1.0` when bound; a property yields its payload
/// value when that is a JSON number. Anything else yields `None`.
fn resolve_inner_numeric<T: VectorType>(
    inner: &ReturnExprKind,
    row: &HashMap<String, Node<T>>,
) -> Option<f64> {
    let (var, field) = match inner {
        ReturnExprKind::Var(name) => return row.get(name).map(|_| 1.0),
        ReturnExprKind::Property(name, field) => (name, field),
        _ => return None,
    };
    row.get(var)?.payload.get(field)?.as_f64()
}
/// JSON input of a `Collect` aggregate for one row: the node id for a bare
/// variable, or a copy of the payload field for a property. Unbound variables,
/// absent fields and nested aggregates yield `None`.
fn resolve_inner_json<T: VectorType>(
    inner: &ReturnExprKind,
    row: &HashMap<String, Node<T>>,
) -> Option<serde_json::Value> {
    match inner {
        ReturnExprKind::Var(name) => {
            let node = row.get(name)?;
            Some(serde_json::json!(node.id))
        }
        ReturnExprKind::Property(name, field) => {
            let node = row.get(name)?;
            node.payload.get(field).cloned()
        }
        _ => None,
    }
}
/// Builds the EXPLAIN output: a single row binding `"plan"` to a synthetic node
/// (id 0, no vector, no edges) whose payload is a JSON object describing the
/// query.
///
/// Keys written, in this order: `entry`, `detail`, `candidate_strategy`,
/// `predicate`, `return`, `optimizations`, `total_nodes`, then `limit`,
/// `offset` and `order_by_count` when the query has them.
fn generate_explain_plan<T: VectorType>(query: &TqlQuery, mt: &MemTable<T>) -> TqlResult<T> {
    // --- Describe the individual query parts ---------------------------------
    let (entry_type, entry_detail): (&str, String) = match &query.entry {
        QueryEntry::Find { filter } => ("FIND", format!("{:?}", filter)),
        QueryEntry::Match { pattern } => ("MATCH", format_pattern_detail(pattern)),
        QueryEntry::OptionalMatch { pattern } => {
            ("OPTIONAL MATCH", format_pattern_detail(pattern))
        }
        QueryEntry::Search { top_k, expand, .. } => {
            let suffix = if expand.is_some() { " + EXPAND" } else { "" };
            ("SEARCH", format!("TOP {}{}", top_k, suffix))
        }
    };

    let strategy: String = match &query.entry {
        QueryEntry::Match { pattern } | QueryEntry::OptionalMatch { pattern } => {
            analyze_candidate_strategy(&pattern.nodes[0], pattern.edges.first(), mt)
        }
        QueryEntry::Find { .. } => String::from("full_scan"),
        QueryEntry::Search { .. } => String::from("vector_index"),
    };

    let predicate_desc = match &query.predicate {
        Some(pred) => format!("{:?}", pred),
        None => String::from("none"),
    };

    let return_info = match &query.returns {
        ReturnClause::All => String::from("ALL (*)"),
        ReturnClause::Variables(vars) => format!("variables: [{}]", vars.join(", ")),
        ReturnClause::Expressions(exprs) => {
            let descs: Vec<String> = exprs.iter().map(format_return_expr).collect();
            format!("expressions: [{}]", descs.join(", "))
        }
    };

    // --- Optimizations that apply to this query ------------------------------
    let mut optimizations: Vec<&str> = Vec::new();
    if strategy.contains("id_shortcut") {
        optimizations.push("ID O(1) shortcut");
    }
    if strategy.contains("label_index") {
        optimizations.push("label index pushdown");
    }
    if let ReturnClause::Expressions(exprs) = &query.returns {
        if !get_prunable_vars(exprs).is_empty() {
            optimizations.push("projection pruning");
        }
        if exprs.iter().any(|e| e.distinct) {
            optimizations.push("DISTINCT dedup");
        }
        if exprs.iter().any(|e| is_aggregate(&e.kind)) {
            optimizations.push("aggregation");
        }
    }
    if query.limit.is_some() {
        optimizations.push("LIMIT early termination");
    }

    // --- Assemble the plan object (insertion order kept as documented) -------
    let mut plan = serde_json::Map::new();
    plan.insert("entry".into(), serde_json::json!(entry_type));
    plan.insert("detail".into(), serde_json::json!(entry_detail));
    plan.insert("candidate_strategy".into(), serde_json::json!(strategy));
    plan.insert("predicate".into(), serde_json::json!(predicate_desc));
    plan.insert("return".into(), serde_json::json!(return_info));
    plan.insert("optimizations".into(), serde_json::json!(optimizations));
    plan.insert(
        "total_nodes".into(),
        serde_json::json!(mt.all_node_ids().len()),
    );
    if let Some(lim) = query.limit {
        plan.insert("limit".into(), serde_json::json!(lim));
    }
    if let Some(off) = query.offset {
        plan.insert("offset".into(), serde_json::json!(off));
    }
    if !query.order_by.is_empty() {
        plan.insert(
            "order_by_count".into(),
            serde_json::json!(query.order_by.len()),
        );
    }

    let plan_node = Node {
        id: 0,
        vector: Vec::new(),
        payload: serde_json::Value::Object(plan),
        edges: Vec::new(),
    };
    vec![HashMap::from([("plan".to_string(), plan_node)])]
}
/// Names the start-candidate strategy for EXPLAIN output.
///
/// A filtered start node reports the id shortcut (if the filter pins an id) or
/// a filter scan; an unfiltered start node followed by a labelled edge reports
/// label-index pushdown; everything else reports a full scan.
fn analyze_candidate_strategy<T: VectorType>(
    node_pat: &TqlNodePattern,
    first_edge: Option<&TqlEdgePattern>,
    _mt: &MemTable<T>,
) -> String {
    match &node_pat.filter {
        Some(filter) if extract_id_from_filter(filter).is_some() => {
            "id_shortcut O(1)".to_string()
        }
        Some(_) => "filter_scan (with inline filter)".to_string(),
        None => match first_edge {
            Some(edge_pat) if !edge_pat.labels.is_empty() => format!(
                "label_index pushdown (labels: [{}])",
                edge_pat.labels.join(", ")
            ),
            _ => "full_scan O(N)".to_string(),
        },
    }
}
/// Renders a MATCH pattern as compact text for EXPLAIN, e.g.
/// `(a {filter})-[:KNOWS*1..3]->(b)`.
///
/// Unnamed nodes print as `_`, filters as the literal marker ` {filter}`, and
/// every edge is drawn with a forward arrow.
fn format_pattern_detail(pattern: &TqlPattern) -> String {
    let mut text = String::new();

    for (pos, node) in pattern.nodes.iter().enumerate() {
        text.push('(');
        text.push_str(node.var.as_deref().unwrap_or("_"));
        if node.filter.is_some() {
            text.push_str(" {filter}");
        }
        text.push(')');

        // The edge at the same index (if any) leaves this node.
        if let Some(edge) = pattern.edges.get(pos) {
            let labels = if edge.labels.is_empty() {
                String::new()
            } else {
                format!(":{}", edge.labels.join("|"))
            };
            let hops = match &edge.hop_range {
                Some(hop) => format!("*{}..{}", hop.min, hop.max),
                None => String::new(),
            };
            text.push_str(&format!("-[{}{}]->", labels, hops));
        }
    }

    text
}
/// Renders one return expression for EXPLAIN as `[DISTINCT ]<expr>[ AS <alias>]`.
fn format_return_expr(expr: &ReturnExpr) -> String {
    let prefix = if expr.distinct { "DISTINCT " } else { "" };
    let body = format_return_expr_kind(&expr.kind);
    match &expr.alias {
        Some(alias) => format!("{}{} AS {}", prefix, body, alias),
        None => format!("{}{}", prefix, body),
    }
}
/// Renders the core of a return expression: `var`, `var.field`, or
/// `Func(<inner>)` using the aggregate function's debug name.
fn format_return_expr_kind(kind: &ReturnExprKind) -> String {
    match kind {
        ReturnExprKind::Aggregate(func, inner) => {
            let rendered_inner = format_return_expr_kind(inner);
            format!("{:?}({})", func, rendered_inner)
        }
        ReturnExprKind::Property(var, field) => format!("{}.{}", var, field),
        ReturnExprKind::Var(var) => var.clone(),
    }
}
/// Strips vectors and edge lists from nodes that are only accessed through
/// properties in an expression-style RETURN, leaving just id and payload.
/// Other RETURN forms are left untouched.
fn apply_projection_pruning<T: VectorType>(returns: &ReturnClause, results: &mut TqlResult<T>) {
    let ReturnClause::Expressions(exprs) = returns else {
        return;
    };

    let prunable = get_prunable_vars(exprs);
    if prunable.is_empty() {
        return;
    }

    for row in results.iter_mut() {
        for (name, node) in row.iter_mut() {
            if prunable.contains(name) {
                node.vector.clear();
                node.edges.clear();
            }
        }
    }
}
/// Returns the variables that are referenced only via properties (`a.name`) and
/// never as a whole node (`a`); their vectors and edges are not needed. The
/// order of the returned names is unspecified.
fn get_prunable_vars(exprs: &[ReturnExpr]) -> Vec<String> {
    let mut whole_node_vars: HashSet<String> = HashSet::new();
    let mut property_vars: HashSet<String> = HashSet::new();
    for expr in exprs {
        classify_vars(&expr.kind, &mut whole_node_vars, &mut property_vars);
    }

    property_vars
        .into_iter()
        .filter(|name| !whole_node_vars.contains(name))
        .collect()
}
/// Records how a return expression uses its variable: whole-node references go
/// into `full_vars`, property accesses into `prop_vars`. Aggregates are
/// classified by the expression they wrap.
fn classify_vars(
    kind: &ReturnExprKind,
    full_vars: &mut HashSet<String>,
    prop_vars: &mut HashSet<String>,
) {
    match kind {
        ReturnExprKind::Aggregate(_, inner) => classify_vars(inner, full_vars, prop_vars),
        ReturnExprKind::Property(name, _) => {
            prop_vars.insert(name.clone());
        }
        ReturnExprKind::Var(name) => {
            full_vars.insert(name.clone());
        }
    }
}
/// Plans a TQL mutation (CREATE / SET / DELETE) against the current state of
/// `mt` and returns the operations to perform. Nothing is written here.
///
/// # Errors
/// Propagates errors from the action-specific planners, e.g. SET or DELETE
/// without a MATCH source, or a failing MATCH.
pub fn execute_tql_mutation<T: VectorType>(
    mutation: &TqlMutation,
    mt: &MemTable<T>,
) -> Result<Vec<MutationOp<T>>, TriviumError> {
    let source = &mutation.source;
    match &mutation.action {
        MutationAction::Delete { vars, detach } => execute_delete(vars, *detach, source, mt),
        MutationAction::Set(assignments) => execute_set(assignments, source, mt),
        MutationAction::Create(create) => execute_create(create, source, mt),
    }
}
/// Plans a CREATE action.
///
/// Variables bound by the optional MATCH source (first matching row only) refer
/// to existing nodes and are not re-created. Every other node pattern becomes an
/// `InsertNode` with a default-valued vector of the memtable's dimension;
/// unnamed nodes use the variable name `"_anon"`. Each edge pattern becomes a
/// `LinkEdge`.
///
/// # Errors
/// Propagates errors from resolving the MATCH source.
// NOTE(review): an edge endpoint that is not bound by the MATCH source (for
// example a node created in this same statement) is emitted with id 0 —
// confirm how the applier resolves that.
fn execute_create<T: VectorType>(
    create: &CreateAction,
    source: &Option<MutationSource>,
    mt: &MemTable<T>,
) -> Result<Vec<MutationOp<T>>, TriviumError> {
    let dim = mt.dim();
    let matched_vars: HashMap<String, NodeId> = match source {
        Some(src) => resolve_match_vars(&src.pattern, src.predicate.as_ref(), mt)?,
        None => HashMap::new(),
    };

    let mut ops = Vec::new();

    for node in &create.nodes {
        let var = node.var.as_deref().unwrap_or("_anon");
        // Already bound by MATCH: the node exists, nothing to insert.
        if matched_vars.contains_key(var) {
            continue;
        }
        ops.push(MutationOp::InsertNode {
            var: var.to_string(),
            vector: vec![T::default(); dim],
            payload: node.payload.clone(),
        });
    }

    for edge in &create.edges {
        // Unresolved endpoints fall back to id 0.
        let src_id = matched_vars.get(&edge.src_var).copied().unwrap_or(0);
        let dst_id = matched_vars.get(&edge.dst_var).copied().unwrap_or(0);
        ops.push(MutationOp::LinkEdge {
            src_id,
            dst_id,
            label: edge.label.clone(),
            weight: edge.weight,
        });
    }

    Ok(ops)
}
fn execute_set<T: VectorType>(
assignments: &[SetAssignment],
source: &Option<MutationSource>,
mt: &MemTable<T>,
) -> Result<Vec<MutationOp<T>>, TriviumError> {
let source = source
.as_ref()
.ok_or_else(|| TriviumError::QueryParse("SET requires a preceding MATCH clause".into()))?;
let query = build_match_query(&source.pattern, source.predicate.as_ref());
let results = execute_tql(&query, mt)?;
let mut ops = Vec::new();
for row in &results {
for assign in assignments {
if let Some(node) = row.get(&assign.var) {
let mut new_payload = node.payload.clone();
if let Some(obj) = new_payload.as_object_mut() {
obj.insert(assign.field.clone(), assign.value.clone());
}
ops.push(MutationOp::UpdatePayload {
id: node.id,
payload: new_payload,
});
}
}
}
Ok(ops)
}
/// Plans a DELETE action: runs the MATCH source and emits one `DeleteNode` per
/// distinct node bound to any of `vars`, in order of first appearance.
///
/// # Errors
/// * `TriviumError::QueryParse` if there is no MATCH source.
/// * Any error from executing the MATCH.
fn execute_delete<T: VectorType>(
    vars: &[String],
    detach: bool,
    source: &Option<MutationSource>,
    mt: &MemTable<T>,
) -> Result<Vec<MutationOp<T>>, TriviumError> {
    let Some(source) = source.as_ref() else {
        return Err(TriviumError::QueryParse(
            "DELETE requires a preceding MATCH clause".into(),
        ));
    };

    let query = build_match_query(&source.pattern, source.predicate.as_ref());
    let results = execute_tql(&query, mt)?;

    // Tracks ids already scheduled so each node is deleted only once.
    let mut scheduled: HashSet<NodeId> = HashSet::new();
    let ops = results
        .iter()
        .flat_map(|row| vars.iter().filter_map(move |var| row.get(var)))
        .filter(|node| scheduled.insert(node.id))
        .map(|node| MutationOp::DeleteNode {
            id: node.id,
            detach,
        })
        .collect();
    Ok(ops)
}
/// Runs a MATCH and returns the variable → node id bindings of its first result
/// row (an empty map when nothing matched).
///
/// # Errors
/// Propagates errors from executing the MATCH.
fn resolve_match_vars<T: VectorType>(
    pattern: &TqlPattern,
    predicate: Option<&Predicate>,
    mt: &MemTable<T>,
) -> Result<HashMap<String, NodeId>, TriviumError> {
    let rows = execute_tql(&build_match_query(pattern, predicate), mt)?;

    let bindings: HashMap<String, NodeId> = match rows.first() {
        Some(row) => row
            .iter()
            .map(|(var, node)| (var.clone(), node.id))
            .collect(),
        None => HashMap::new(),
    };
    Ok(bindings)
}
/// Wraps a pattern and optional predicate in a plain `MATCH ... RETURN *` query
/// (no EXPLAIN, ordering, limit or offset) for use by the mutation planners.
fn build_match_query(pattern: &TqlPattern, predicate: Option<&Predicate>) -> TqlQuery {
    let entry = QueryEntry::Match {
        pattern: pattern.clone(),
    };
    let predicate = predicate.cloned();

    TqlQuery {
        explain: false,
        entry,
        predicate,
        returns: ReturnClause::All,
        order_by: Vec::new(),
        limit: None,
        offset: None,
    }
}