1use crate::ast::*;
4use mentedb_core::edge::EdgeType;
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Execution plan produced by [`plan`] for a single parsed statement.
///
/// Each variant names one access path or mutation together with the
/// parameters the planner extracted from the statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QueryPlan {
    /// Search by vector. `k` is taken from the statement's limit (or the
    /// default limit) and `filters` are the filters left for later evaluation.
    ///
    /// `query` is the statement's explicit `near` vector when one was given;
    /// the planner leaves it empty when the plan comes from a similar-to
    /// (`~>`) text filter instead.
    VectorSearch {
        query: Vec<f32>,
        k: usize,
        filters: Vec<Filter>,
    },
    /// Scan by tag. `tags` may be empty, in which case the planner uses this
    /// variant as its fallback and passes every filter through in `filters`.
    TagScan {
        tags: Vec<String>,
        filters: Vec<Filter>,
        limit: Option<usize>,
    },
    /// Scan over a timestamp range from `start` to `end`, with `filters`
    /// holding the non-temporal filters.
    ///
    /// The planner initialises the range to `0..=u64::MAX` and narrows it from
    /// integer-valued comparisons. This variant carries no limit.
    // NOTE(review): whether the bounds are inclusive is not visible here —
    // the planner maps both strict and non-strict comparisons to the same value.
    TemporalScan {
        start: Timestamp,
        end: Timestamp,
        filters: Vec<Filter>,
    },
    /// Graph walk from `start` limited to `depth`, restricted to `edge_types`.
    /// The planner passes an empty list when the statement has no edge filter.
    GraphTraversal {
        start: MemoryId,
        depth: usize,
        edge_types: Vec<EdgeType>,
    },
    /// Lookup of a single memory by `id`.
    PointLookup {
        id: MemoryId,
    },
    /// Insertion of an edge of `edge_type` from `source` to `target`.
    /// The planner uses a `weight` of `1.0` when the statement gives none.
    EdgeInsert {
        source: MemoryId,
        target: MemoryId,
        edge_type: EdgeType,
        weight: f32,
    },
    /// Deletion of the memory identified by `id`.
    Delete {
        id: MemoryId,
    },
    /// Consolidation restricted by `filters`, copied from the statement.
    Consolidate {
        filters: Vec<Filter>,
    },
}
48
/// Result limit used for `RECALL` plans when the statement specifies none.
const DEFAULT_LIMIT: usize = 20;
50
51pub fn plan(statement: &Statement) -> MenteResult<QueryPlan> {
53 match statement {
54 Statement::Recall(recall) => plan_recall(recall),
55 Statement::Relate(relate) => plan_relate(relate),
56 Statement::Forget(forget) => Ok(QueryPlan::Delete { id: forget.target }),
57 Statement::Consolidate(cons) => Ok(QueryPlan::Consolidate {
58 filters: cons.filters.clone(),
59 }),
60 Statement::Traverse(trav) => plan_traverse(trav),
61 }
62}
63
64fn plan_recall(recall: &RecallStatement) -> MenteResult<QueryPlan> {
65 let limit = recall.limit.unwrap_or(DEFAULT_LIMIT);
66
67 if let Some(ref vec) = recall.near {
69 return Ok(QueryPlan::VectorSearch {
70 query: vec.clone(),
71 k: limit,
72 filters: recall.filters.clone(),
73 });
74 }
75
76 if let Some(sim_filter) = recall.filters.iter().find(|f| f.op == Operator::SimilarTo) {
78 if let Value::Text(ref _text) = sim_filter.value {
79 let remaining: Vec<Filter> = recall
82 .filters
83 .iter()
84 .filter(|f| f.op != Operator::SimilarTo)
85 .cloned()
86 .collect();
87 return Ok(QueryPlan::VectorSearch {
88 query: Vec::new(), k: limit,
90 filters: remaining,
91 });
92 }
93 return Err(MenteError::Query(
95 "~> operator requires a text value on the right-hand side".into(),
96 ));
97 }
98
99 let tag_filters: Vec<&Filter> = recall
101 .filters
102 .iter()
103 .filter(|f| f.field == Field::Tag)
104 .collect();
105 if !tag_filters.is_empty() && recall.filters.iter().all(|f| f.field == Field::Tag) {
106 let tags: Vec<String> = tag_filters
107 .iter()
108 .filter_map(|f| match &f.value {
109 Value::Text(t) => Some(t.clone()),
110 _ => None,
111 })
112 .collect();
113 return Ok(QueryPlan::TagScan {
114 tags,
115 filters: Vec::new(),
116 limit: Some(limit),
117 });
118 }
119
120 let time_filters: Vec<&Filter> = recall
122 .filters
123 .iter()
124 .filter(|f| {
125 (f.field == Field::Created || f.field == Field::Accessed)
126 && matches!(
127 f.op,
128 Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
129 )
130 })
131 .collect();
132
133 if !time_filters.is_empty() {
134 let remaining: Vec<Filter> = recall
135 .filters
136 .iter()
137 .filter(|f| {
138 !((f.field == Field::Created || f.field == Field::Accessed)
139 && matches!(
140 f.op,
141 Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
142 ))
143 })
144 .cloned()
145 .collect();
146
147 let mut start: Timestamp = 0;
148 let mut end: Timestamp = u64::MAX;
149 for f in &time_filters {
150 if let Value::Text(ref s) = f.value {
151 let _ = s; }
155 if let Value::Integer(ts) = f.value {
156 let ts = ts as u64;
157 match f.op {
158 Operator::Gt | Operator::Gte => start = ts,
159 Operator::Lt | Operator::Lte => end = ts,
160 _ => {}
161 }
162 }
163 }
164
165 return Ok(QueryPlan::TemporalScan {
166 start,
167 end,
168 filters: remaining,
169 });
170 }
171
172 Ok(QueryPlan::TagScan {
174 tags: Vec::new(),
175 filters: recall.filters.clone(),
176 limit: Some(limit),
177 })
178}
179
180fn plan_relate(relate: &RelateStatement) -> MenteResult<QueryPlan> {
181 Ok(QueryPlan::EdgeInsert {
182 source: relate.source,
183 target: relate.target,
184 edge_type: relate.edge_type,
185 weight: relate.weight.unwrap_or(1.0),
186 })
187}
188
189fn plan_traverse(trav: &TraverseStatement) -> MenteResult<QueryPlan> {
190 Ok(QueryPlan::GraphTraversal {
191 start: trav.start,
192 depth: trav.depth,
193 edge_types: trav.edge_filter.clone().unwrap_or_default(),
194 })
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;
    use crate::parser::Parser;

    /// Lexes, parses and plans `input`, panicking if any stage fails.
    fn plan_mql(input: &str) -> QueryPlan {
        let lexed = tokenize(input).unwrap();
        let statement = Parser::parse(&lexed).unwrap();
        plan(&statement).unwrap()
    }

    /// An explicit `NEAR` vector is searched as-is with the given limit.
    #[test]
    fn test_near_produces_vector_search() {
        match plan_mql("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 5") {
            QueryPlan::VectorSearch { query, k, .. } => {
                assert_eq!(query, vec![0.1, 0.2, 0.3]);
                assert_eq!(k, 5);
            }
            other => panic!("expected VectorSearch, got {:?}", other),
        }
    }

    /// A `~>` text filter is planned as a vector search.
    #[test]
    fn test_similar_to_produces_vector_search() {
        match plan_mql(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#) {
            QueryPlan::VectorSearch { k, .. } => {
                assert_eq!(k, 10);
            }
            other => panic!("expected VectorSearch, got {:?}", other),
        }
    }

    /// A lone tag filter is planned as a tag scan over that tag.
    #[test]
    fn test_tag_filter_produces_tag_scan() {
        match plan_mql(r#"RECALL memories WHERE tag = "backend" LIMIT 5"#) {
            QueryPlan::TagScan { tags, limit, .. } => {
                assert_eq!(tags, vec!["backend".to_string()]);
                assert_eq!(limit, Some(5));
            }
            other => panic!("expected TagScan, got {:?}", other),
        }
    }

    /// `FORGET <id>` is planned as a delete of that id.
    #[test]
    fn test_forget_produces_delete() {
        match plan_mql("FORGET 550e8400-e29b-41d4-a716-446655440000") {
            QueryPlan::Delete { id } => {
                let expected = "550e8400-e29b-41d4-a716-446655440000"
                    .parse::<MemoryId>()
                    .unwrap();
                assert_eq!(id, expected);
            }
            other => panic!("expected Delete, got {:?}", other),
        }
    }

    /// `TRAVERSE` keeps its depth and edge-type restriction.
    #[test]
    fn test_traverse_produces_graph_traversal() {
        let planned = plan_mql(
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused",
        );
        match planned {
            QueryPlan::GraphTraversal {
                depth, edge_types, ..
            } => {
                assert_eq!(depth, 3);
                assert_eq!(edge_types, vec![EdgeType::Caused]);
            }
            other => panic!("expected GraphTraversal, got {:?}", other),
        }
    }

    /// `RELATE ... AS ... WITH weight` keeps the edge type and weight.
    #[test]
    fn test_relate_produces_edge_insert() {
        let planned = plan_mql(
            "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.8",
        );
        match planned {
            QueryPlan::EdgeInsert {
                edge_type, weight, ..
            } => {
                assert_eq!(edge_type, EdgeType::Caused);
                assert!((weight - 0.8).abs() < f32::EPSILON);
            }
            other => panic!("expected EdgeInsert, got {:?}", other),
        }
    }
}