1use crate::ast::*;
4use mentedb_core::edge::EdgeType;
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// Executable plan produced by [`plan`] from a parsed MQL statement.
///
/// Each variant names the access path to execute together with the parameters the
/// planner extracted from the statement.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QueryPlan {
    /// Vector similarity search, produced for `RECALL … NEAR [..]` and for `~>` filters.
    VectorSearch {
        /// Query vector; left empty when the plan was derived from a `~>` text filter.
        query: Vec<f32>,
        /// Number of results requested: the statement's `LIMIT`, or the planner default.
        k: usize,
        /// Additional filters to apply to the search.
        filters: Vec<Filter>,
    },
    /// Tag-based scan. Produced when every `RECALL` filter is on `tag`, and also as the
    /// fallback plan, in which case `tags` is empty and all filters travel in `filters`.
    TagScan {
        /// Tag values extracted from `tag` filters.
        tags: Vec<String>,
        /// Residual filters the scan must still apply.
        filters: Vec<Filter>,
        /// Maximum number of results, if bounded.
        limit: Option<usize>,
    },
    /// Time-window scan built from `>`, `<`, `>=`, `<=` comparisons on `created`/`accessed`.
    /// NOTE(review): whether `start`/`end` are inclusive is up to the executor — confirm.
    TemporalScan {
        /// Lower bound of the window; 0 when no lower bound was given.
        start: Timestamp,
        /// Upper bound of the window; `u64::MAX` when no upper bound was given.
        end: Timestamp,
        /// Non-temporal filters to apply within the window.
        filters: Vec<Filter>,
    },
    /// Graph walk produced by `TRAVERSE`.
    GraphTraversal {
        /// Memory the traversal begins at.
        start: MemoryId,
        /// Traversal depth taken from the statement's `DEPTH` clause.
        depth: usize,
        /// Edge types to follow; empty when the statement had no edge-type filter
        /// (presumably meaning "any type" — confirm in the executor).
        edge_types: Vec<EdgeType>,
    },
    /// Single-memory lookup by id. Not currently emitted by [`plan`].
    PointLookup {
        /// Id of the memory to fetch.
        id: MemoryId,
    },
    /// Edge creation produced by `RELATE`.
    EdgeInsert {
        /// Memory the edge starts from.
        source: MemoryId,
        /// Memory the edge points to.
        target: MemoryId,
        /// Relationship kind given after `AS`.
        edge_type: EdgeType,
        /// Edge weight; 1.0 when the statement did not specify one.
        weight: f32,
    },
    /// Removal of a single memory, produced by `FORGET`.
    Delete {
        /// Id of the memory to remove.
        id: MemoryId,
    },
    /// Consolidation request produced by `CONSOLIDATE`.
    Consolidate {
        /// Filters copied verbatim from the statement.
        filters: Vec<Filter>,
    },
}
48
/// Result count used by `RECALL` when the statement has no `LIMIT` clause.
const DEFAULT_LIMIT: usize = 20;
50
51pub fn plan(statement: &Statement) -> MenteResult<QueryPlan> {
53 match statement {
54 Statement::Recall(recall) => plan_recall(recall),
55 Statement::Relate(relate) => plan_relate(relate),
56 Statement::Forget(forget) => Ok(QueryPlan::Delete { id: forget.target }),
57 Statement::Consolidate(cons) => Ok(QueryPlan::Consolidate {
58 filters: cons.filters.clone(),
59 }),
60 Statement::Traverse(trav) => plan_traverse(trav),
61 }
62}
63
64fn plan_recall(recall: &RecallStatement) -> MenteResult<QueryPlan> {
65 let limit = recall.limit.unwrap_or(DEFAULT_LIMIT);
66
67 if let Some(ref vec) = recall.near {
69 return Ok(QueryPlan::VectorSearch {
70 query: vec.clone(),
71 k: limit,
72 filters: recall.filters.clone(),
73 });
74 }
75
76 if let Some(sim_filter) = recall.filters.iter().find(|f| f.op == Operator::SimilarTo) {
78 if let Value::Text(ref _text) = sim_filter.value {
79 let remaining: Vec<Filter> = recall
82 .filters
83 .iter()
84 .filter(|f| f.op != Operator::SimilarTo)
85 .cloned()
86 .collect();
87 return Ok(QueryPlan::VectorSearch {
88 query: Vec::new(), k: limit,
90 filters: remaining,
91 });
92 }
93 return Err(MenteError::Query(
95 "~> operator requires a text value on the right-hand side".into(),
96 ));
97 }
98
99 let tag_filters: Vec<&Filter> = recall
101 .filters
102 .iter()
103 .filter(|f| f.field == Field::Tag)
104 .collect();
105 if !tag_filters.is_empty() && recall.filters.iter().all(|f| f.field == Field::Tag) {
106 let tags: Vec<String> = tag_filters
107 .iter()
108 .filter_map(|f| match &f.value {
109 Value::Text(t) => Some(t.clone()),
110 _ => None,
111 })
112 .collect();
113 return Ok(QueryPlan::TagScan {
114 tags,
115 filters: Vec::new(),
116 limit: Some(limit),
117 });
118 }
119
120 let time_filters: Vec<&Filter> = recall
122 .filters
123 .iter()
124 .filter(|f| {
125 (f.field == Field::Created || f.field == Field::Accessed)
126 && matches!(
127 f.op,
128 Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
129 )
130 })
131 .collect();
132
133 if !time_filters.is_empty() {
134 let remaining: Vec<Filter> = recall
135 .filters
136 .iter()
137 .filter(|f| {
138 !((f.field == Field::Created || f.field == Field::Accessed)
139 && matches!(
140 f.op,
141 Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
142 ))
143 })
144 .cloned()
145 .collect();
146
147 let mut start: Timestamp = 0;
148 let mut end: Timestamp = u64::MAX;
149 for f in &time_filters {
150 if let Value::Text(ref s) = f.value {
151 let _ = s; }
155 if let Value::Integer(ts) = f.value {
156 let ts = ts as u64;
157 match f.op {
158 Operator::Gt | Operator::Gte => start = ts,
159 Operator::Lt | Operator::Lte => end = ts,
160 _ => {}
161 }
162 }
163 }
164
165 return Ok(QueryPlan::TemporalScan {
166 start,
167 end,
168 filters: remaining,
169 });
170 }
171
172 Ok(QueryPlan::TagScan {
174 tags: Vec::new(),
175 filters: recall.filters.clone(),
176 limit: Some(limit),
177 })
178}
179
180fn plan_relate(relate: &RelateStatement) -> MenteResult<QueryPlan> {
181 Ok(QueryPlan::EdgeInsert {
182 source: relate.source,
183 target: relate.target,
184 edge_type: relate.edge_type,
185 weight: relate.weight.unwrap_or(1.0),
186 })
187}
188
189fn plan_traverse(trav: &TraverseStatement) -> MenteResult<QueryPlan> {
190 Ok(QueryPlan::GraphTraversal {
191 start: trav.start,
192 depth: trav.depth,
193 edge_types: trav.edge_filter.clone().unwrap_or_default(),
194 })
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;
    use crate::parser::Parser;
    use uuid::Uuid;

    /// Runs `input` through the lexer, parser and planner, panicking on any failure.
    fn plan_mql(input: &str) -> QueryPlan {
        let tokens = tokenize(input).unwrap();
        let statement = Parser::parse(&tokens).unwrap();
        plan(&statement).unwrap()
    }

    /// Parses a UUID literal used by the fixtures below.
    fn uuid(literal: &str) -> Uuid {
        literal.parse::<Uuid>().unwrap()
    }

    #[test]
    fn test_near_produces_vector_search() {
        let qp = plan_mql("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 5");
        if let QueryPlan::VectorSearch { query, k, .. } = qp {
            assert_eq!(query, vec![0.1, 0.2, 0.3]);
            assert_eq!(k, 5);
        } else {
            panic!("expected VectorSearch, got {:?}", qp);
        }
    }

    #[test]
    fn test_similar_to_produces_vector_search() {
        let qp = plan_mql(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#);
        if let QueryPlan::VectorSearch { k, .. } = qp {
            assert_eq!(k, 10);
        } else {
            panic!("expected VectorSearch, got {:?}", qp);
        }
    }

    #[test]
    fn test_tag_filter_produces_tag_scan() {
        let qp = plan_mql(r#"RECALL memories WHERE tag = "backend" LIMIT 5"#);
        if let QueryPlan::TagScan { tags, limit, .. } = qp {
            assert_eq!(tags, vec!["backend".to_string()]);
            assert_eq!(limit, Some(5));
        } else {
            panic!("expected TagScan, got {:?}", qp);
        }
    }

    #[test]
    fn test_forget_produces_delete() {
        let qp = plan_mql("FORGET 550e8400-e29b-41d4-a716-446655440000");
        if let QueryPlan::Delete { id } = qp {
            assert_eq!(id, uuid("550e8400-e29b-41d4-a716-446655440000"));
        } else {
            panic!("expected Delete, got {:?}", qp);
        }
    }

    #[test]
    fn test_traverse_produces_graph_traversal() {
        let qp = plan_mql(
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused",
        );
        if let QueryPlan::GraphTraversal {
            depth, edge_types, ..
        } = qp
        {
            assert_eq!(depth, 3);
            assert_eq!(edge_types, vec![EdgeType::Caused]);
        } else {
            panic!("expected GraphTraversal, got {:?}", qp);
        }
    }

    #[test]
    fn test_relate_produces_edge_insert() {
        let qp = plan_mql(
            "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.8",
        );
        if let QueryPlan::EdgeInsert {
            edge_type, weight, ..
        } = qp
        {
            assert_eq!(edge_type, EdgeType::Caused);
            assert!((weight - 0.8).abs() < f32::EPSILON);
        } else {
            panic!("expected EdgeInsert, got {:?}", qp);
        }
    }
}