use crate::ast::*;
use mentedb_core::edge::EdgeType;
use mentedb_core::error::{MenteError, MenteResult};
use mentedb_core::types::{MemoryId, Timestamp};
use serde::{Deserialize, Serialize};
/// Execution plan produced by [`plan`] for a single parsed statement.
///
/// Each variant carries the parameters the planner extracted from the statement;
/// how a variant is executed is decided outside this module.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QueryPlan {
/// Built by `plan_recall` when the statement has a `near` vector or a
/// `SimilarTo` filter with a text value.
VectorSearch {
/// The statement's `near` vector; empty when the plan was derived from a
/// `SimilarTo` filter.
query: Vec<f32>,
/// The statement's limit, or `DEFAULT_LIMIT` when none was given.
k: usize,
/// Filters carried along with the search (the `SimilarTo` filters are
/// removed when the plan was derived from one).
filters: Vec<Filter>,
},
/// Built by `plan_recall` when every filter targets `Field::Tag`, and also
/// used as the fallback plan when no other strategy applies.
TagScan {
/// Text values of the tag filters; empty for the fallback plan.
tags: Vec<String>,
/// Empty for a tag-only plan; all statement filters for the fallback plan.
filters: Vec<Filter>,
/// The planner in this file always sets this to `Some(..)`.
limit: Option<usize>,
},
/// Built by `plan_recall` when the statement has a range comparison on
/// `Field::Created` or `Field::Accessed`.
TemporalScan {
/// Lower bound taken from `>` / `>=` integer filters; `0` when there is none.
start: Timestamp,
/// Upper bound taken from `<` / `<=` integer filters; `u64::MAX` when there is none.
end: Timestamp,
/// Filters that were not folded into `start` / `end`.
filters: Vec<Filter>,
},
/// Built by `plan_traverse` from a traverse statement.
GraphTraversal {
/// The statement's `start` id.
start: MemoryId,
/// The statement's `depth`.
depth: usize,
/// The statement's edge filter; empty when the statement has none.
edge_types: Vec<EdgeType>,
},
/// Lookup of a single memory by id. Not constructed by the planner
/// functions in this file.
PointLookup {
id: MemoryId,
},
/// Built by `plan_relate` from a relate statement.
EdgeInsert {
source: MemoryId,
target: MemoryId,
edge_type: EdgeType,
/// The statement's weight, or `1.0` when none was given.
weight: f32,
},
/// Built by `plan` from a forget statement; `id` is the statement's target.
Delete {
id: MemoryId,
},
/// Built by `plan` from a consolidate statement; `filters` is a copy of the
/// statement's filters.
Consolidate {
filters: Vec<Filter>,
},
}
/// Result limit `plan_recall` falls back to when the statement carries no limit.
const DEFAULT_LIMIT: usize = 20;
/// Converts a parsed statement into the [`QueryPlan`] that represents it.
///
/// Recall, relate and traverse statements are delegated to their dedicated
/// planners; forget and consolidate statements map directly onto a plan
/// variant.
///
/// # Errors
///
/// Propagates any error returned by the delegated planner.
pub fn plan(statement: &Statement) -> MenteResult<QueryPlan> {
    match statement {
        // Statements that need their own planning logic.
        Statement::Recall(stmt) => plan_recall(stmt),
        Statement::Relate(stmt) => plan_relate(stmt),
        Statement::Traverse(stmt) => plan_traverse(stmt),
        // Statements that translate one-to-one into a plan.
        Statement::Forget(stmt) => {
            let id = stmt.target;
            Ok(QueryPlan::Delete { id })
        }
        Statement::Consolidate(stmt) => {
            let filters = stmt.filters.clone();
            Ok(QueryPlan::Consolidate { filters })
        }
    }
}
/// Returns `true` when `filter` is a range comparison (`>`, `>=`, `<`, `<=`)
/// on the `Created` or `Accessed` field, i.e. a candidate bound for a
/// temporal scan.
fn is_time_range_filter(filter: &Filter) -> bool {
    (filter.field == Field::Created || filter.field == Field::Accessed)
        && matches!(
            filter.op,
            Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
        )
}

/// Chooses the plan for a recall statement.
///
/// Strategies are tried in order and the first one that applies wins:
///
/// 1. A `near` vector yields a `VectorSearch` over that vector.
/// 2. A `SimilarTo` filter yields a `VectorSearch` with an empty query vector
///    and the remaining (non-`SimilarTo`) filters.
/// 3. A non-empty filter list in which every filter targets `Field::Tag`
///    yields a `TagScan` over the text values of those filters.
/// 4. Any range comparison on `Created` / `Accessed` yields a `TemporalScan`.
///    Integer-valued bounds are folded into `start` / `end`, keeping the
///    tightest bound on each side; every other filter, including time
///    comparisons whose value is not an integer, stays in `filters` so the
///    constraint is not lost.
/// 5. Otherwise a `TagScan` with no tags that carries all filters.
///
/// The result limit is the statement's limit, or `DEFAULT_LIMIT` when absent.
///
/// # Errors
///
/// Returns `MenteError::Query` when the first `SimilarTo` filter has a
/// non-text value, or when an integer time bound is negative and therefore
/// cannot be represented as a `Timestamp`.
fn plan_recall(recall: &RecallStatement) -> MenteResult<QueryPlan> {
    let limit = recall.limit.unwrap_or(DEFAULT_LIMIT);

    if let Some(vector) = &recall.near {
        return Ok(QueryPlan::VectorSearch {
            query: vector.clone(),
            k: limit,
            filters: recall.filters.clone(),
        });
    }

    if let Some(sim_filter) = recall
        .filters
        .iter()
        .find(|f| f.op == Operator::SimilarTo)
    {
        if !matches!(sim_filter.value, Value::Text(_)) {
            return Err(MenteError::Query(
                "~> operator requires a text value on the right-hand side".into(),
            ));
        }
        // NOTE(review): the text is not turned into an embedding here and the
        // SimilarTo filter is removed, so the plan carries an empty query
        // vector. Presumably the embedding is supplied later; confirm.
        let remaining: Vec<Filter> = recall
            .filters
            .iter()
            .filter(|f| f.op != Operator::SimilarTo)
            .cloned()
            .collect();
        return Ok(QueryPlan::VectorSearch {
            query: Vec::new(),
            k: limit,
            filters: remaining,
        });
    }

    let tag_only =
        !recall.filters.is_empty() && recall.filters.iter().all(|f| f.field == Field::Tag);
    if tag_only {
        // Tag filters whose value is not text contribute no tag.
        let tags: Vec<String> = recall
            .filters
            .iter()
            .filter_map(|f| match &f.value {
                Value::Text(tag) => Some(tag.clone()),
                _ => None,
            })
            .collect();
        return Ok(QueryPlan::TagScan {
            tags,
            filters: Vec::new(),
            limit: Some(limit),
        });
    }

    if recall.filters.iter().any(is_time_range_filter) {
        let mut start: Timestamp = 0;
        let mut end: Timestamp = u64::MAX;
        let mut remaining: Vec<Filter> = Vec::new();

        // NOTE(review): bounds on `Created` and `Accessed` share one range, and
        // strict (`>`, `<`) and inclusive (`>=`, `<=`) comparisons map to the
        // same bound; whether that matches the scan's semantics depends on the
        // executor. Confirm.
        for filter in &recall.filters {
            if is_time_range_filter(filter) {
                if let Value::Integer(raw) = filter.value {
                    // A plain `as u64` cast would wrap a negative value into a
                    // huge timestamp; reject it instead.
                    let ts = u64::try_from(raw).map_err(|_| {
                        MenteError::Query("timestamp bounds must be non-negative".into())
                    })?;
                    match filter.op {
                        // Several bounds on one side intersect: keep the tightest.
                        Operator::Gt | Operator::Gte => start = start.max(ts),
                        Operator::Lt | Operator::Lte => end = end.min(ts),
                        // `is_time_range_filter` admits only the operators above.
                        _ => {}
                    }
                    continue;
                }
            }
            // Not folded into the range: keep it for later evaluation.
            remaining.push(filter.clone());
        }

        return Ok(QueryPlan::TemporalScan {
            start,
            end,
            filters: remaining,
        });
    }

    Ok(QueryPlan::TagScan {
        tags: Vec::new(),
        filters: recall.filters.clone(),
        limit: Some(limit),
    })
}
/// Builds the `EdgeInsert` plan for a relate statement.
///
/// Source, target and edge type are copied from the statement; the weight is
/// the statement's weight, or `1.0` when none was given. Never returns `Err`.
fn plan_relate(relate: &RelateStatement) -> MenteResult<QueryPlan> {
    let weight = match relate.weight {
        Some(explicit) => explicit,
        None => 1.0,
    };
    let edge_insert = QueryPlan::EdgeInsert {
        source: relate.source,
        target: relate.target,
        edge_type: relate.edge_type,
        weight,
    };
    Ok(edge_insert)
}
/// Builds the `GraphTraversal` plan for a traverse statement.
///
/// Start id and depth are copied from the statement; the edge types are a
/// copy of the statement's edge filter, or an empty list when it has none.
/// Never returns `Err`.
fn plan_traverse(trav: &TraverseStatement) -> MenteResult<QueryPlan> {
    let edge_types = match &trav.edge_filter {
        Some(types) => types.clone(),
        None => Vec::new(),
    };
    let traversal = QueryPlan::GraphTraversal {
        start: trav.start,
        depth: trav.depth,
        edge_types,
    };
    Ok(traversal)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;
    use crate::parser::Parser;

    /// Tokenizes, parses and plans `input`, panicking if any stage fails.
    fn plan_mql(input: &str) -> QueryPlan {
        let token_stream = tokenize(input).unwrap();
        let statement = Parser::parse(&token_stream).unwrap();
        plan(&statement).unwrap()
    }

    /// A `NEAR` vector is planned as a vector search over that vector.
    #[test]
    fn test_near_produces_vector_search() {
        let qp = plan_mql("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 5");
        if let QueryPlan::VectorSearch { query, k, .. } = &qp {
            assert_eq!(*query, vec![0.1, 0.2, 0.3]);
            assert_eq!(*k, 5);
        } else {
            panic!("expected VectorSearch, got {:?}", qp);
        }
    }

    /// A `~>` filter is planned as a vector search honoring the limit.
    #[test]
    fn test_similar_to_produces_vector_search() {
        let qp = plan_mql(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#);
        if let QueryPlan::VectorSearch { k, .. } = &qp {
            assert_eq!(*k, 10);
        } else {
            panic!("expected VectorSearch, got {:?}", qp);
        }
    }

    /// A tag-only filter is planned as a tag scan over that tag.
    #[test]
    fn test_tag_filter_produces_tag_scan() {
        let qp = plan_mql(r#"RECALL memories WHERE tag = "backend" LIMIT 5"#);
        if let QueryPlan::TagScan { tags, limit, .. } = &qp {
            assert_eq!(*tags, vec!["backend".to_string()]);
            assert_eq!(*limit, Some(5));
        } else {
            panic!("expected TagScan, got {:?}", qp);
        }
    }

    /// `FORGET` is planned as a delete of the given id.
    #[test]
    fn test_forget_produces_delete() {
        let qp = plan_mql("FORGET 550e8400-e29b-41d4-a716-446655440000");
        if let QueryPlan::Delete { id } = &qp {
            let expected = "550e8400-e29b-41d4-a716-446655440000"
                .parse::<MemoryId>()
                .unwrap();
            assert_eq!(*id, expected);
        } else {
            panic!("expected Delete, got {:?}", qp);
        }
    }

    /// `TRAVERSE` is planned as a graph traversal with depth and edge filter.
    #[test]
    fn test_traverse_produces_graph_traversal() {
        let qp = plan_mql(
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused",
        );
        if let QueryPlan::GraphTraversal {
            depth, edge_types, ..
        } = &qp
        {
            assert_eq!(*depth, 3);
            assert_eq!(*edge_types, vec![EdgeType::Caused]);
        } else {
            panic!("expected GraphTraversal, got {:?}", qp);
        }
    }

    /// `RELATE` is planned as an edge insert with the given type and weight.
    #[test]
    fn test_relate_produces_edge_insert() {
        let qp = plan_mql(
            "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.8",
        );
        if let QueryPlan::EdgeInsert {
            edge_type, weight, ..
        } = &qp
        {
            assert_eq!(*edge_type, EdgeType::Caused);
            assert!((*weight - 0.8).abs() < f32::EPSILON);
        } else {
            panic!("expected EdgeInsert, got {:?}", qp);
        }
    }
}