Skip to main content

contextdb_planner/
planner.rs

1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4    AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5    SortDirection, Statement,
6};
7use std::collections::HashMap;
8
/// Hard ceiling on the effective `max_hops` of a single `MATCH` edge step; steps that
/// ask for more are rejected with `Error::BfsDepthExceeded` during planning.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;

/// Depth used for a `MATCH` edge step whose `max_hops` is `0`.
const DEFAULT_MATCH_DEPTH: u32 = 5;

/// `max_depth` given to a propagation rule that does not specify one.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14    match stmt {
15        Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16            name: ct.name.clone(),
17            columns: ct.columns.clone(),
18            unique_constraints: ct.unique_constraints.clone(),
19            immutable: ct.immutable,
20            state_machine: ct.state_machine.clone(),
21            dag_edge_types: ct.dag_edge_types.clone(),
22            propagation_rules: extract_propagation_rules(ct)?,
23            retain: ct.retain.clone(),
24        })),
25        Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
26            table: at.table.clone(),
27            action: at.action.clone(),
28        })),
29        Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
30        Statement::CreateIndex(ci) => Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
31            name: ci.name.clone(),
32            table: ci.table.clone(),
33            columns: ci.columns.clone(),
34        })),
35        Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
36            table: i.table.clone(),
37            columns: i.columns.clone(),
38            values: i.values.clone(),
39            on_conflict: i.on_conflict.clone().map(Into::into),
40        })),
41        Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
42            table: d.table.clone(),
43            where_clause: d.where_clause.clone(),
44        })),
45        Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
46            table: u.table.clone(),
47            assignments: u.assignments.clone(),
48            where_clause: u.where_clause.clone(),
49        })),
50        Statement::Select(sel) => plan_select(sel),
51        Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
52        Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
53        Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
54        Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
55        Statement::SetSyncConflictPolicy(policy) => {
56            Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
57        }
58        Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
59        Statement::Begin | Statement::Commit | Statement::Rollback => {
60            Ok(PhysicalPlan::Pipeline(vec![]))
61        }
62    }
63}
64
65fn extract_propagation_rules(
66    ct: &contextdb_parser::ast::CreateTable,
67) -> Result<Vec<PropagationRule>> {
68    let mut rules = Vec::new();
69
70    for column in &ct.columns {
71        if let Some(fk) = &column.references {
72            for rule in &fk.propagation_rules {
73                if let AstPropagationRule::FkState {
74                    trigger_state,
75                    target_state,
76                    max_depth,
77                    abort_on_failure,
78                } = rule
79                {
80                    rules.push(PropagationRule::ForeignKey {
81                        fk_column: column.name.clone(),
82                        referenced_table: fk.table.clone(),
83                        referenced_column: fk.column.clone(),
84                        trigger_state: trigger_state.clone(),
85                        target_state: target_state.clone(),
86                        max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
87                        abort_on_failure: *abort_on_failure,
88                    });
89                }
90            }
91        }
92    }
93
94    for rule in &ct.propagation_rules {
95        match rule {
96            AstPropagationRule::EdgeState {
97                edge_type,
98                direction,
99                trigger_state,
100                target_state,
101                max_depth,
102                abort_on_failure,
103            } => {
104                let direction = match direction.to_ascii_uppercase().as_str() {
105                    "OUTGOING" => Direction::Outgoing,
106                    "INCOMING" => Direction::Incoming,
107                    "BOTH" => Direction::Both,
108                    other => {
109                        return Err(Error::PlanError(format!(
110                            "invalid edge direction in propagation rule: {}",
111                            other
112                        )));
113                    }
114                };
115                rules.push(PropagationRule::Edge {
116                    edge_type: edge_type.clone(),
117                    direction,
118                    trigger_state: trigger_state.clone(),
119                    target_state: target_state.clone(),
120                    max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
121                    abort_on_failure: *abort_on_failure,
122                });
123            }
124            AstPropagationRule::VectorExclusion { trigger_state } => {
125                rules.push(PropagationRule::VectorExclusion {
126                    trigger_state: trigger_state.clone(),
127                });
128            }
129            AstPropagationRule::FkState { .. } => {}
130        }
131    }
132
133    Ok(rules)
134}
135
136fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
137    let mut cte_env = HashMap::new();
138
139    for cte in &sel.ctes {
140        match cte {
141            Cte::MatchCte { name, match_clause } => {
142                cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
143            }
144            Cte::SqlCte { name, query } => {
145                cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
146            }
147        }
148    }
149
150    plan_select_body(&sel.body, &cte_env)
151}
152
153fn plan_select_body(
154    body: &SelectBody,
155    cte_env: &HashMap<String, PhysicalPlan>,
156) -> Result<PhysicalPlan> {
157    let graph_from = body
158        .from
159        .iter()
160        .find(|f| matches!(f, FromItem::GraphTable { .. }));
161
162    let mut current = if let Some(from_item) = graph_from {
163        graph_plan_from_from_item(from_item, cte_env)?
164    } else {
165        let from_item = body.from.iter().find_map(|item| match item {
166            FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
167            FromItem::GraphTable { .. } => None,
168        });
169
170        match from_item {
171            Some((from_table, from_alias)) => {
172                if let Some(cte_plan) = cte_env.get(&from_table) {
173                    let mut cte_plan = cte_plan.clone();
174                    if body.joins.is_empty()
175                        && let Some(where_clause) = &body.where_clause
176                    {
177                        cte_plan = PhysicalPlan::Filter {
178                            input: Box::new(cte_plan),
179                            predicate: where_clause.clone(),
180                        };
181                    }
182                    cte_plan
183                } else {
184                    PhysicalPlan::Scan {
185                        table: from_table,
186                        alias: from_alias.clone(),
187                        filter: if body.joins.is_empty() {
188                            body.where_clause.clone()
189                        } else {
190                            None
191                        },
192                    }
193                }
194            }
195            None => PhysicalPlan::Scan {
196                table: "dual".to_string(),
197                alias: None,
198                filter: None,
199            },
200        }
201    };
202
203    if !body.joins.is_empty() {
204        let left_alias = body.from.iter().find_map(|item| match item {
205            FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
206            FromItem::GraphTable { .. } => None,
207        });
208
209        for join in &body.joins {
210            let right = if let Some(cte_plan) = cte_env.get(&join.table) {
211                cte_plan.clone()
212            } else {
213                PhysicalPlan::Scan {
214                    table: join.table.clone(),
215                    alias: join.alias.clone(),
216                    filter: None,
217                }
218            };
219
220            current = PhysicalPlan::Join {
221                left: Box::new(current),
222                right: Box::new(right),
223                condition: join.on.clone(),
224                join_type: match join.join_type {
225                    contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
226                    contextdb_parser::ast::JoinType::Left => JoinType::Left,
227                },
228                left_alias: left_alias.clone(),
229                right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
230            };
231        }
232
233        if let Some(where_clause) = &body.where_clause {
234            current = PhysicalPlan::Filter {
235                input: Box::new(current),
236                predicate: where_clause.clone(),
237            };
238        }
239    }
240
241    let uses_vector_search = body
242        .order_by
243        .first()
244        .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
245
246    if let Some(order) = body.order_by.first()
247        && matches!(order.direction, SortDirection::CosineDistance)
248    {
249        let k = body.limit.ok_or(Error::UnboundedVectorSearch)?;
250        let vector_table = vector_base_table(&current)?.ok_or_else(|| {
251            Error::PlanError("unable to resolve physical vector source table".to_string())
252        })?;
253        current = PhysicalPlan::VectorSearch {
254            table: vector_table,
255            column: "embedding".to_string(),
256            query_expr: order.expr.clone(),
257            k,
258            candidates: Some(Box::new(current)),
259        };
260    }
261
262    if !body.order_by.is_empty() && !uses_vector_search {
263        current = PhysicalPlan::Sort {
264            input: Box::new(current),
265            keys: body
266                .order_by
267                .iter()
268                .map(|item| SortKey {
269                    expr: item.expr.clone(),
270                    direction: item.direction,
271                })
272                .collect(),
273        };
274    }
275
276    let is_select_star = matches!(
277        body.columns.as_slice(),
278        [contextdb_parser::ast::SelectColumn {
279            expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
280            alias: None
281        }] if column == "*"
282    );
283    if !is_select_star {
284        current = PhysicalPlan::Project {
285            input: Box::new(current),
286            columns: body
287                .columns
288                .iter()
289                .map(|column| ProjectColumn {
290                    expr: column.expr.clone(),
291                    alias: column.alias.clone(),
292                })
293                .collect(),
294        };
295    }
296
297    if body.distinct {
298        current = PhysicalPlan::Distinct {
299            input: Box::new(current),
300        };
301    }
302
303    if let Some(limit) = body.limit
304        && !uses_vector_search
305    {
306        current = PhysicalPlan::Limit {
307            input: Box::new(current),
308            count: limit,
309        };
310    }
311
312    Ok(current)
313}
314
315fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
316    match plan {
317        PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
318            Ok(Some(table.clone()))
319        }
320        PhysicalPlan::Filter { input, .. }
321        | PhysicalPlan::Project { input, .. }
322        | PhysicalPlan::Distinct { input }
323        | PhysicalPlan::Sort { input, .. }
324        | PhysicalPlan::Limit { input, .. }
325        | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
326        PhysicalPlan::Join { left, right, .. } => {
327            let left_table = vector_base_table(left)?;
328            let right_table = vector_base_table(right)?;
329            match (left_table, right_table) {
330                (Some(left), Some(right)) if left == right => Ok(Some(left)),
331                (Some(_), Some(_)) => Err(Error::PlanError(
332                    "ambiguous physical vector source table in join".to_string(),
333                )),
334                (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
335                (None, None) => Ok(None),
336            }
337        }
338        PhysicalPlan::Pipeline(plans) => {
339            for plan in plans.iter().rev() {
340                if let Some(table) = vector_base_table(plan)? {
341                    return Ok(Some(table));
342                }
343            }
344            Ok(None)
345        }
346        PhysicalPlan::GraphBfs { .. }
347        | PhysicalPlan::CteRef { .. }
348        | PhysicalPlan::Union { .. } => Ok(None),
349        _ => Ok(None),
350    }
351}
352
353fn graph_plan_from_from_item(
354    from_item: &FromItem,
355    cte_env: &HashMap<String, PhysicalPlan>,
356) -> Result<PhysicalPlan> {
357    match from_item {
358        FromItem::GraphTable {
359            match_clause,
360            columns,
361            ..
362        } => {
363            let bfs = graph_bfs_from_match(match_clause, cte_env)?;
364            if columns.is_empty() {
365                Ok(bfs)
366            } else {
367                Ok(PhysicalPlan::Project {
368                    input: Box::new(bfs),
369                    columns: columns
370                        .iter()
371                        .map(|c| ProjectColumn {
372                            expr: c.expr.clone(),
373                            alias: Some(c.alias.clone()),
374                        })
375                        .collect(),
376                })
377            }
378        }
379        FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
380            table: name.clone(),
381            alias: None,
382            filter: None,
383        }),
384    }
385}
386
387fn graph_bfs_from_match(
388    match_clause: &contextdb_parser::ast::MatchClause,
389    cte_env: &HashMap<String, PhysicalPlan>,
390) -> Result<PhysicalPlan> {
391    let steps = match_clause
392        .pattern
393        .edges
394        .iter()
395        .map(|step| {
396            let max_depth = if step.max_hops == 0 {
397                DEFAULT_MATCH_DEPTH
398            } else {
399                step.max_hops
400            };
401            if max_depth > ENGINE_MAX_BFS_DEPTH {
402                return Err(Error::BfsDepthExceeded(max_depth));
403            }
404
405            Ok(GraphStepPlan {
406                edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
407                direction: match step.direction {
408                    contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
409                    contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
410                    contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
411                },
412                min_depth: step.min_hops.max(1),
413                max_depth,
414                target_alias: step.target.alias.clone(),
415            })
416        })
417        .collect::<Result<Vec<_>>>()?;
418    if steps.is_empty() {
419        return Err(Error::PlanError(
420            "MATCH must include at least one edge".into(),
421        ));
422    }
423
424    Ok(PhysicalPlan::GraphBfs {
425        start_alias: match_clause.pattern.start.alias.clone(),
426        start_expr: extract_graph_start_expr(match_clause)?,
427        start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
428        steps,
429        filter: match_clause.where_clause.clone(),
430    })
431}
432
433fn extract_graph_start_candidates(
434    match_clause: &MatchClause,
435    cte_env: &HashMap<String, PhysicalPlan>,
436) -> Result<Option<Box<PhysicalPlan>>> {
437    let Some(where_clause) = &match_clause.where_clause else {
438        return Ok(None);
439    };
440    find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
441}
442
443fn find_graph_start_candidates(
444    expr: &Expr,
445    start_alias: &str,
446    cte_env: &HashMap<String, PhysicalPlan>,
447) -> Result<Option<Box<PhysicalPlan>>> {
448    match expr {
449        Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
450            Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
451        }
452        Expr::BinaryOp { left, right, .. } => {
453            if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
454                return Ok(Some(plan));
455            }
456            find_graph_start_candidates(right, start_alias, cte_env)
457        }
458        Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
459        _ => Ok(None),
460    }
461}
462
463fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
464    let start_alias = &match_clause.pattern.start.alias;
465    if let Some(where_clause) = &match_clause.where_clause
466        && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
467    {
468        return Ok(expr);
469    }
470
471    Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
472        table: None,
473        column: start_alias.clone(),
474    }))
475}
476
477fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
478    match expr {
479        Expr::BinaryOp {
480            left,
481            op: BinOp::Eq,
482            right,
483        } => {
484            if is_graph_start_id_ref(left, start_alias) {
485                Some((**right).clone())
486            } else if is_graph_start_id_ref(right, start_alias) {
487                Some((**left).clone())
488            } else {
489                None
490            }
491        }
492        Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
493            .or_else(|| find_graph_start_expr(right, start_alias)),
494        Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
495        _ => None,
496    }
497}
498
499fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
500    matches!(
501        expr,
502        Expr::Column(contextdb_parser::ast::ColumnRef {
503            table: Some(table),
504            column
505        }) if table == start_alias && column == "id"
506    )
507}