Skip to main content

mentedb_query/
planner.rs

1//! Query planner — converts a parsed `Statement` into a `QueryPlan`.
2
3use crate::ast::*;
4use mentedb_core::edge::EdgeType;
5use mentedb_core::error::{MenteError, MenteResult};
6use mentedb_core::types::{MemoryId, Timestamp};
7use serde::{Deserialize, Serialize};
8
/// A physical query plan describing how to execute a query.
///
/// Each variant names one execution strategy together with the inputs that
/// strategy needs. Plans are plain data: they can be cloned, compared,
/// and (de)serialized with serde.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QueryPlan {
    /// Nearest-neighbour search over embeddings.
    VectorSearch {
        /// Query embedding. The planner leaves this empty when the query used a
        /// text similarity operator; the executor is then expected to embed the text.
        query: Vec<f32>,
        /// Maximum number of results to return.
        k: usize,
        /// Additional predicates attached to the search.
        filters: Vec<Filter>,
    },
    /// Scan by tag. With an empty `tags` list the planner uses this variant as
    /// its fallback "full scan with filters" plan.
    TagScan {
        /// Tag values taken from the query's tag filters.
        tags: Vec<String>,
        /// Additional predicates attached to the scan.
        filters: Vec<Filter>,
        /// Optional cap on the number of results.
        limit: Option<usize>,
    },
    /// Scan over a timestamp range.
    // NOTE(review): whether `start`/`end` are inclusive is decided by the executor — confirm.
    TemporalScan {
        /// Lower bound of the range; `0` when the query gives none.
        start: Timestamp,
        /// Upper bound of the range; `u64::MAX` when the query gives none.
        end: Timestamp,
        /// Predicates that were not folded into the range.
        filters: Vec<Filter>,
    },
    /// Walk the memory graph outward from one node.
    GraphTraversal {
        /// Memory the traversal starts from.
        start: MemoryId,
        /// Traversal depth requested by the statement.
        depth: usize,
        /// Edge types named by the statement's edge filter; empty when the
        /// statement has no edge filter.
        // NOTE(review): presumably an empty list means "follow every edge type" — confirm in executor.
        edge_types: Vec<EdgeType>,
    },
    /// Fetch a single memory by id. Not constructed by the planner functions in this file.
    PointLookup {
        /// Memory to fetch.
        id: MemoryId,
    },
    /// Create an edge between two memories.
    EdgeInsert {
        /// Memory the edge starts at.
        source: MemoryId,
        /// Memory the edge points to.
        target: MemoryId,
        /// Kind of relationship the edge records.
        edge_type: EdgeType,
        /// Edge weight; the planner substitutes `1.0` when the statement gives none.
        weight: f32,
    },
    /// Remove a single memory.
    Delete {
        /// Memory to remove.
        id: MemoryId,
    },
    /// Consolidation restricted by the given predicates.
    Consolidate {
        /// Predicates copied from the statement.
        filters: Vec<Filter>,
    },
}
48
/// Result cap used when planning a recall statement that carries no explicit limit.
const DEFAULT_LIMIT: usize = 20;
50
51/// Produce an execution plan from a parsed statement.
52pub fn plan(statement: &Statement) -> MenteResult<QueryPlan> {
53    match statement {
54        Statement::Recall(recall) => plan_recall(recall),
55        Statement::Relate(relate) => plan_relate(relate),
56        Statement::Forget(forget) => Ok(QueryPlan::Delete { id: forget.target }),
57        Statement::Consolidate(cons) => Ok(QueryPlan::Consolidate {
58            filters: cons.filters.clone(),
59        }),
60        Statement::Traverse(trav) => plan_traverse(trav),
61    }
62}
63
64fn plan_recall(recall: &RecallStatement) -> MenteResult<QueryPlan> {
65    let limit = recall.limit.unwrap_or(DEFAULT_LIMIT);
66
67    // If there's a NEAR clause or a SimilarTo filter, use vector search
68    if let Some(ref vec) = recall.near {
69        return Ok(QueryPlan::VectorSearch {
70            query: vec.clone(),
71            k: limit,
72            filters: recall.filters.clone(),
73        });
74    }
75
76    // Check for SimilarTo operator in filters — implies vector search via text embedding
77    if let Some(sim_filter) = recall.filters.iter().find(|f| f.op == Operator::SimilarTo) {
78        if let Value::Text(ref _text) = sim_filter.value {
79            // The text will need to be embedded at execution time; we still emit VectorSearch
80            // with an empty query vec — the executor is responsible for embedding the text.
81            let remaining: Vec<Filter> = recall
82                .filters
83                .iter()
84                .filter(|f| f.op != Operator::SimilarTo)
85                .cloned()
86                .collect();
87            return Ok(QueryPlan::VectorSearch {
88                query: Vec::new(), // placeholder — executor embeds text
89                k: limit,
90                filters: remaining,
91            });
92        }
93        // SimilarTo with non-text value doesn't make sense
94        return Err(MenteError::Query(
95            "~> operator requires a text value on the right-hand side".into(),
96        ));
97    }
98
99    // If only tag filters, use TagScan
100    let tag_filters: Vec<&Filter> = recall
101        .filters
102        .iter()
103        .filter(|f| f.field == Field::Tag)
104        .collect();
105    if !tag_filters.is_empty() && recall.filters.iter().all(|f| f.field == Field::Tag) {
106        let tags: Vec<String> = tag_filters
107            .iter()
108            .filter_map(|f| match &f.value {
109                Value::Text(t) => Some(t.clone()),
110                _ => None,
111            })
112            .collect();
113        return Ok(QueryPlan::TagScan {
114            tags,
115            filters: Vec::new(),
116            limit: Some(limit),
117        });
118    }
119
120    // If time-range filters exist (created or accessed with range ops), use TemporalScan
121    let time_filters: Vec<&Filter> = recall
122        .filters
123        .iter()
124        .filter(|f| {
125            (f.field == Field::Created || f.field == Field::Accessed)
126                && matches!(
127                    f.op,
128                    Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
129                )
130        })
131        .collect();
132
133    if !time_filters.is_empty() {
134        let remaining: Vec<Filter> = recall
135            .filters
136            .iter()
137            .filter(|f| {
138                !((f.field == Field::Created || f.field == Field::Accessed)
139                    && matches!(
140                        f.op,
141                        Operator::Gt | Operator::Lt | Operator::Gte | Operator::Lte
142                    ))
143            })
144            .cloned()
145            .collect();
146
147        let mut start: Timestamp = 0;
148        let mut end: Timestamp = u64::MAX;
149        for f in &time_filters {
150            if let Value::Text(ref s) = f.value {
151                // Simple heuristic: treat text timestamps as orderable strings for now.
152                // A real implementation would parse dates. We use 0/MAX as placeholders.
153                let _ = s; // acknowledged
154            }
155            if let Value::Integer(ts) = f.value {
156                let ts = ts as u64;
157                match f.op {
158                    Operator::Gt | Operator::Gte => start = ts,
159                    Operator::Lt | Operator::Lte => end = ts,
160                    _ => {}
161                }
162            }
163        }
164
165        return Ok(QueryPlan::TemporalScan {
166            start,
167            end,
168            filters: remaining,
169        });
170    }
171
172    // Fallback: tag scan with no tags (full scan with filters)
173    Ok(QueryPlan::TagScan {
174        tags: Vec::new(),
175        filters: recall.filters.clone(),
176        limit: Some(limit),
177    })
178}
179
180fn plan_relate(relate: &RelateStatement) -> MenteResult<QueryPlan> {
181    Ok(QueryPlan::EdgeInsert {
182        source: relate.source,
183        target: relate.target,
184        edge_type: relate.edge_type,
185        weight: relate.weight.unwrap_or(1.0),
186    })
187}
188
189fn plan_traverse(trav: &TraverseStatement) -> MenteResult<QueryPlan> {
190    Ok(QueryPlan::GraphTraversal {
191        start: trav.start,
192        depth: trav.depth,
193        edge_types: trav.edge_filter.clone().unwrap_or_default(),
194    })
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;
    use crate::parser::Parser;
    use uuid::Uuid;

    /// Lex, parse and plan `input`, panicking if any stage fails.
    fn plan_mql(input: &str) -> QueryPlan {
        let tokens = tokenize(input).unwrap();
        let stmt = Parser::parse(&tokens).unwrap();
        plan(&stmt).unwrap()
    }

    /// A NEAR clause yields a vector search over the given vector.
    #[test]
    fn test_near_produces_vector_search() {
        let planned = plan_mql("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 5");
        if let QueryPlan::VectorSearch { query, k, .. } = &planned {
            assert_eq!(*query, vec![0.1, 0.2, 0.3]);
            assert_eq!(*k, 5);
        } else {
            panic!("expected VectorSearch, got {:?}", planned);
        }
    }

    /// A `~>` filter yields a vector search honouring the LIMIT.
    #[test]
    fn test_similar_to_produces_vector_search() {
        let planned =
            plan_mql(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#);
        if let QueryPlan::VectorSearch { k, .. } = &planned {
            assert_eq!(*k, 10);
        } else {
            panic!("expected VectorSearch, got {:?}", planned);
        }
    }

    /// A lone tag filter yields a tag scan for that tag.
    #[test]
    fn test_tag_filter_produces_tag_scan() {
        let planned = plan_mql(r#"RECALL memories WHERE tag = "backend" LIMIT 5"#);
        if let QueryPlan::TagScan { tags, limit, .. } = &planned {
            assert_eq!(*tags, vec!["backend".to_string()]);
            assert_eq!(*limit, Some(5));
        } else {
            panic!("expected TagScan, got {:?}", planned);
        }
    }

    /// FORGET yields a delete of the named memory.
    #[test]
    fn test_forget_produces_delete() {
        let planned = plan_mql("FORGET 550e8400-e29b-41d4-a716-446655440000");
        if let QueryPlan::Delete { id } = &planned {
            let expected = "550e8400-e29b-41d4-a716-446655440000"
                .parse::<Uuid>()
                .unwrap();
            assert_eq!(*id, expected);
        } else {
            panic!("expected Delete, got {:?}", planned);
        }
    }

    /// TRAVERSE yields a graph traversal with the requested depth and edge filter.
    #[test]
    fn test_traverse_produces_graph_traversal() {
        let planned = plan_mql(
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused",
        );
        if let QueryPlan::GraphTraversal {
            depth, edge_types, ..
        } = &planned
        {
            assert_eq!(*depth, 3);
            assert_eq!(*edge_types, vec![EdgeType::Caused]);
        } else {
            panic!("expected GraphTraversal, got {:?}", planned);
        }
    }

    /// RELATE yields an edge insert carrying the edge type and weight.
    #[test]
    fn test_relate_produces_edge_insert() {
        let planned = plan_mql(
            "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.8",
        );
        if let QueryPlan::EdgeInsert {
            edge_type, weight, ..
        } = &planned
        {
            assert_eq!(*edge_type, EdgeType::Caused);
            assert!((*weight - 0.8).abs() < f32::EPSILON);
        } else {
            panic!("expected EdgeInsert, got {:?}", planned);
        }
    }
}