Skip to main content

mentedb_query/
planner.rs

1//! Query planner: converts a parsed `Statement` into a `QueryPlan`.
2
3use crate::ast::*;
4use mentedb_core::edge::EdgeType;
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
9/// A physical query plan describing how to execute a query.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub enum QueryPlan {
12    VectorSearch {
13        query: Vec<f32>,
14        k: usize,
15        filters: Vec<Filter>,
16    },
17    TagScan {
18        tags: Vec<String>,
19        filters: Vec<Filter>,
20        limit: Option<usize>,
21    },
22    TemporalScan {
23        start: Timestamp,
24        end: Timestamp,
25        filters: Vec<Filter>,
26    },
27    GraphTraversal {
28        start: MemoryId,
29        depth: usize,
30        edge_types: Vec<EdgeType>,
31    },
32    PointLookup {
33        id: MemoryId,
34    },
35    EdgeInsert {
36        source: MemoryId,
37        target: MemoryId,
38        edge_type: EdgeType,
39        weight: f32,
40    },
41    Delete {
42        id: MemoryId,
43    },
44    Consolidate {
45        filters: Vec<Filter>,
46    },
47}
48
/// Result limit applied to a `RECALL` statement that has no explicit `LIMIT` clause.
const DEFAULT_LIMIT: usize = 20;
50
51/// Produce an execution plan from a parsed statement.
52pub fn plan(statement: &Statement) -> MenteResult<QueryPlan> {
53    match statement {
54        Statement::Recall(recall) => plan_recall(recall),
55        Statement::Relate(relate) => plan_relate(relate),
56        Statement::Forget(forget) => Ok(QueryPlan::Delete { id: forget.target }),
57        Statement::Consolidate(cons) => Ok(QueryPlan::Consolidate {
58            filters: cons.filters.clone(),
59        }),
60        Statement::Traverse(trav) => plan_traverse(trav),
61    }
62}
63
64fn plan_recall(recall: &RecallStatement) -> MenteResult<QueryPlan> {
65    let limit = recall.limit.unwrap_or(DEFAULT_LIMIT);
66
67    // If there's a NEAR clause or a SimilarTo filter, use vector search
68    if let Some(ref vec) = recall.near {
69        return Ok(QueryPlan::VectorSearch {
70            query: vec.clone(),
71            k: limit,
72            filters: recall.filters.clone(),
73        });
74    }
75
76    // Check for SimilarTo operator in filters — implies vector search via text embedding
77    if let Some(sim_filter) = recall.filters.iter().find(|f| f.op == Operator::SimilarTo) {
78        if let Value::Text(ref _text) = sim_filter.value {
79            // The text will need to be embedded at execution time; we still emit VectorSearch
80            // with an empty query vec — the executor is responsible for embedding the text.
81            let remaining: Vec<Filter> = recall
82                .filters
83                .iter()
84                .filter(|f| f.op != Operator::SimilarTo)
85                .cloned()
86                .collect();
87            return Ok(QueryPlan::VectorSearch {
88                query: Vec::new(), // placeholder — executor embeds text
89                k: limit,
90                filters: remaining,
91            });
92        }
93        // SimilarTo with non-text value doesn't make sense
94        return Err(MenteError::Query(
95            "~> operator requires a text value on the right-hand side".into(),
96        ));
97    }
98
99    // If only tag filters, use TagScan
100    let tag_filters: Vec<&Filter> = recall
101        .filters
102        .iter()
103        .filter(|f| f.field == Field::Tag)
104        .collect();
105    if !tag_filters.is_empty() && recall.filters.iter().all(|f| f.field == Field::Tag) {
106        let tags: Vec<String> = tag_filters
107            .iter()
108            .filter_map(|f| match &f.value {
109                Value::Text(t) => Some(t.clone()),
110                _ => None,
111            })
112            .collect();
113        return Ok(QueryPlan::TagScan {
114            tags,
115            filters: Vec::new(),
116            limit: Some(limit),
117        });
118    }
119
120    // If time-range filters exist (created or accessed with range ops), use TemporalScan
121    let time_filters: Vec<&Filter> = recall
122        .filters
123        .iter()
124        .filter(|f| {
125            (f.field == Field::Created || f.field == Field::Accessed)
126                && matches!(
127                    f.op,
128                    Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
129                )
130        })
131        .collect();
132
133    if !time_filters.is_empty() {
134        let remaining: Vec<Filter> = recall
135            .filters
136            .iter()
137            .filter(|f| {
138                !((f.field == Field::Created || f.field == Field::Accessed)
139                    && matches!(
140                        f.op,
141                        Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
142                    ))
143            })
144            .cloned()
145            .collect();
146
147        let mut start: Timestamp = 0;
148        let mut end: Timestamp = u64::MAX;
149        for f in &time_filters {
150            if let Value::Text(ref s) = f.value {
151                // Simple heuristic: treat text timestamps as orderable strings for now.
152                // A real implementation would parse dates. We use 0/MAX as placeholders.
153                let _ = s; // acknowledged
154            }
155            if let Value::Integer(ts) = f.value {
156                let ts = ts as u64;
157                match f.op {
158                    Operator::Gt | Operator::Gte => start = ts,
159                    Operator::Lt | Operator::Lte => end = ts,
160                    _ => {}
161                }
162            }
163        }
164
165        return Ok(QueryPlan::TemporalScan {
166            start,
167            end,
168            filters: remaining,
169        });
170    }
171
172    // Fallback: tag scan with no tags (full scan with filters)
173    Ok(QueryPlan::TagScan {
174        tags: Vec::new(),
175        filters: recall.filters.clone(),
176        limit: Some(limit),
177    })
178}
179
180fn plan_relate(relate: &RelateStatement) -> MenteResult<QueryPlan> {
181    Ok(QueryPlan::EdgeInsert {
182        source: relate.source,
183        target: relate.target,
184        edge_type: relate.edge_type,
185        weight: relate.weight.unwrap_or(1.0),
186    })
187}
188
189fn plan_traverse(trav: &TraverseStatement) -> MenteResult<QueryPlan> {
190    Ok(QueryPlan::GraphTraversal {
191        start: trav.start,
192        depth: trav.depth,
193        edge_types: trav.edge_filter.clone().unwrap_or_default(),
194    })
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;
    use crate::parser::Parser;

    const ID_A: &str = "550e8400-e29b-41d4-a716-446655440000";
    const ID_B: &str = "660e8400-e29b-41d4-a716-446655440000";

    /// Lex, parse and plan `mql`, panicking if any stage fails.
    fn compile(mql: &str) -> QueryPlan {
        let tokens = tokenize(mql).unwrap();
        let statement = Parser::parse(&tokens).unwrap();
        plan(&statement).unwrap()
    }

    #[test]
    fn test_near_produces_vector_search() {
        let qp = compile("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 5");
        if let QueryPlan::VectorSearch { query, k, .. } = &qp {
            assert_eq!(*query, vec![0.1_f32, 0.2, 0.3]);
            assert_eq!(*k, 5);
        } else {
            panic!("expected VectorSearch, got {:?}", qp);
        }
    }

    #[test]
    fn test_similar_to_produces_vector_search() {
        let qp = compile(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#);
        if let QueryPlan::VectorSearch { k, .. } = &qp {
            assert_eq!(*k, 10);
        } else {
            panic!("expected VectorSearch, got {:?}", qp);
        }
    }

    #[test]
    fn test_tag_filter_produces_tag_scan() {
        let qp = compile(r#"RECALL memories WHERE tag = "backend" LIMIT 5"#);
        if let QueryPlan::TagScan { tags, limit, .. } = &qp {
            assert_eq!(*tags, vec!["backend".to_string()]);
            assert_eq!(*limit, Some(5));
        } else {
            panic!("expected TagScan, got {:?}", qp);
        }
    }

    #[test]
    fn test_forget_produces_delete() {
        let qp = compile(&format!("FORGET {}", ID_A));
        if let QueryPlan::Delete { id } = &qp {
            let expected: MemoryId = ID_A.parse().unwrap();
            assert_eq!(*id, expected);
        } else {
            panic!("expected Delete, got {:?}", qp);
        }
    }

    #[test]
    fn test_traverse_produces_graph_traversal() {
        let qp = compile(&format!(
            "TRAVERSE {} DEPTH 3 WHERE edge_type = caused",
            ID_A
        ));
        if let QueryPlan::GraphTraversal {
            depth, edge_types, ..
        } = &qp
        {
            assert_eq!(*depth, 3);
            assert_eq!(*edge_types, vec![EdgeType::Caused]);
        } else {
            panic!("expected GraphTraversal, got {:?}", qp);
        }
    }

    #[test]
    fn test_relate_produces_edge_insert() {
        let qp = compile(&format!(
            "RELATE {} -> {} AS caused WITH weight = 0.8",
            ID_A, ID_B
        ));
        if let QueryPlan::EdgeInsert {
            edge_type, weight, ..
        } = &qp
        {
            assert_eq!(*edge_type, EdgeType::Caused);
            assert!((*weight - 0.8).abs() < f32::EPSILON);
        } else {
            panic!("expected EdgeInsert, got {:?}", qp);
        }
    }
}