Skip to main content

contextdb_planner/
planner.rs

1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4    AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5    SortDirection, Statement,
6};
7use std::collections::HashMap;
8
9const DEFAULT_MATCH_DEPTH: u32 = 5;
10const ENGINE_MAX_BFS_DEPTH: u32 = 10;
11const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14    match stmt {
15        Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16            name: ct.name.clone(),
17            columns: ct.columns.clone(),
18            unique_constraints: ct.unique_constraints.clone(),
19            immutable: ct.immutable,
20            state_machine: ct.state_machine.clone(),
21            dag_edge_types: ct.dag_edge_types.clone(),
22            propagation_rules: extract_propagation_rules(ct)?,
23            retain: ct.retain.clone(),
24        })),
25        Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
26            table: at.table.clone(),
27            action: at.action.clone(),
28        })),
29        Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
30        Statement::CreateIndex(ci) => {
31            let mut columns = Vec::with_capacity(ci.columns.len());
32            for (col, dir) in &ci.columns {
33                columns.push((col.clone(), map_parser_to_core_sort_direction(*dir)?));
34            }
35            Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
36                name: ci.name.clone(),
37                table: ci.table.clone(),
38                columns,
39            }))
40        }
41        Statement::DropIndex(di) => Ok(PhysicalPlan::DropIndex(DropIndexPlan {
42            name: di.name.clone(),
43            table: di.table.clone(),
44            if_exists: di.if_exists,
45        })),
46        Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
47            table: i.table.clone(),
48            columns: i.columns.clone(),
49            values: i.values.clone(),
50            on_conflict: i.on_conflict.clone().map(Into::into),
51        })),
52        Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
53            table: d.table.clone(),
54            where_clause: d.where_clause.clone(),
55        })),
56        Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
57            table: u.table.clone(),
58            assignments: u.assignments.clone(),
59            where_clause: u.where_clause.clone(),
60        })),
61        Statement::Select(sel) => plan_select(sel),
62        Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
63        Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
64        Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
65        Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
66        Statement::SetSyncConflictPolicy(policy) => {
67            Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
68        }
69        Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
70        Statement::Begin | Statement::Commit | Statement::Rollback => {
71            Ok(PhysicalPlan::Pipeline(vec![]))
72        }
73    }
74}
75
76fn extract_propagation_rules(
77    ct: &contextdb_parser::ast::CreateTable,
78) -> Result<Vec<PropagationRule>> {
79    let mut rules = Vec::new();
80
81    for column in &ct.columns {
82        if let Some(fk) = &column.references {
83            for rule in &fk.propagation_rules {
84                if let AstPropagationRule::FkState {
85                    trigger_state,
86                    target_state,
87                    max_depth,
88                    abort_on_failure,
89                } = rule
90                {
91                    rules.push(PropagationRule::ForeignKey {
92                        fk_column: column.name.clone(),
93                        referenced_table: fk.table.clone(),
94                        referenced_column: fk.column.clone(),
95                        trigger_state: trigger_state.clone(),
96                        target_state: target_state.clone(),
97                        max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
98                        abort_on_failure: *abort_on_failure,
99                    });
100                }
101            }
102        }
103    }
104
105    for rule in &ct.propagation_rules {
106        match rule {
107            AstPropagationRule::EdgeState {
108                edge_type,
109                direction,
110                trigger_state,
111                target_state,
112                max_depth,
113                abort_on_failure,
114            } => {
115                let direction = match direction.to_ascii_uppercase().as_str() {
116                    "OUTGOING" => Direction::Outgoing,
117                    "INCOMING" => Direction::Incoming,
118                    "BOTH" => Direction::Both,
119                    other => {
120                        return Err(Error::PlanError(format!(
121                            "invalid edge direction in propagation rule: {}",
122                            other
123                        )));
124                    }
125                };
126                rules.push(PropagationRule::Edge {
127                    edge_type: edge_type.clone(),
128                    direction,
129                    trigger_state: trigger_state.clone(),
130                    target_state: target_state.clone(),
131                    max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
132                    abort_on_failure: *abort_on_failure,
133                });
134            }
135            AstPropagationRule::VectorExclusion { trigger_state } => {
136                rules.push(PropagationRule::VectorExclusion {
137                    trigger_state: trigger_state.clone(),
138                });
139            }
140            AstPropagationRule::FkState { .. } => {}
141        }
142    }
143
144    Ok(rules)
145}
146
147fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
148    let mut cte_env = HashMap::new();
149
150    for cte in &sel.ctes {
151        match cte {
152            Cte::MatchCte { name, match_clause } => {
153                cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
154            }
155            Cte::SqlCte { name, query } => {
156                cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
157            }
158        }
159    }
160
161    plan_select_body(&sel.body, &cte_env)
162}
163
164fn plan_select_body(
165    body: &SelectBody,
166    cte_env: &HashMap<String, PhysicalPlan>,
167) -> Result<PhysicalPlan> {
168    let graph_from = body
169        .from
170        .iter()
171        .find(|f| matches!(f, FromItem::GraphTable { .. }));
172
173    let mut current = if let Some(from_item) = graph_from {
174        graph_plan_from_from_item(from_item, cte_env)?
175    } else {
176        let from_item = body.from.iter().find_map(|item| match item {
177            FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
178            FromItem::GraphTable { .. } => None,
179        });
180
181        match from_item {
182            Some((from_table, from_alias)) => {
183                if let Some(cte_plan) = cte_env.get(&from_table) {
184                    let mut cte_plan = cte_plan.clone();
185                    if body.joins.is_empty()
186                        && let Some(where_clause) = &body.where_clause
187                    {
188                        cte_plan = PhysicalPlan::Filter {
189                            input: Box::new(cte_plan),
190                            predicate: where_clause.clone(),
191                        };
192                    }
193                    cte_plan
194                } else {
195                    PhysicalPlan::Scan {
196                        table: from_table,
197                        alias: from_alias.clone(),
198                        filter: if body.joins.is_empty() {
199                            body.where_clause.clone()
200                        } else {
201                            None
202                        },
203                    }
204                }
205            }
206            None => PhysicalPlan::Scan {
207                table: "dual".to_string(),
208                alias: None,
209                filter: None,
210            },
211        }
212    };
213
214    if !body.joins.is_empty() {
215        let left_alias = body.from.iter().find_map(|item| match item {
216            FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
217            FromItem::GraphTable { .. } => None,
218        });
219
220        for join in &body.joins {
221            let right = if let Some(cte_plan) = cte_env.get(&join.table) {
222                cte_plan.clone()
223            } else {
224                PhysicalPlan::Scan {
225                    table: join.table.clone(),
226                    alias: join.alias.clone(),
227                    filter: None,
228                }
229            };
230
231            current = PhysicalPlan::Join {
232                left: Box::new(current),
233                right: Box::new(right),
234                condition: join.on.clone(),
235                join_type: match join.join_type {
236                    contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
237                    contextdb_parser::ast::JoinType::Left => JoinType::Left,
238                },
239                left_alias: left_alias.clone(),
240                right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
241            };
242        }
243
244        if let Some(where_clause) = &body.where_clause {
245            current = PhysicalPlan::Filter {
246                input: Box::new(current),
247                predicate: where_clause.clone(),
248            };
249        }
250    }
251
252    let uses_vector_search = body
253        .order_by
254        .first()
255        .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
256
257    if let Some(order) = body.order_by.first()
258        && matches!(order.direction, SortDirection::CosineDistance)
259    {
260        let k = body.limit.ok_or(Error::UnboundedVectorSearch)?;
261        let vector_table = vector_base_table(&current)?.ok_or_else(|| {
262            Error::PlanError("unable to resolve physical vector source table".to_string())
263        })?;
264        current = PhysicalPlan::VectorSearch {
265            table: vector_table,
266            column: "embedding".to_string(),
267            query_expr: order.expr.clone(),
268            k,
269            candidates: Some(Box::new(current)),
270        };
271    }
272
273    if !body.order_by.is_empty() && !uses_vector_search {
274        current = PhysicalPlan::Sort {
275            input: Box::new(current),
276            keys: body
277                .order_by
278                .iter()
279                .map(|item| SortKey {
280                    expr: item.expr.clone(),
281                    direction: item.direction,
282                })
283                .collect(),
284        };
285    }
286
287    let is_select_star = matches!(
288        body.columns.as_slice(),
289        [contextdb_parser::ast::SelectColumn {
290            expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
291            alias: None
292        }] if column == "*"
293    );
294    if !is_select_star {
295        current = PhysicalPlan::Project {
296            input: Box::new(current),
297            columns: body
298                .columns
299                .iter()
300                .map(|column| ProjectColumn {
301                    expr: column.expr.clone(),
302                    alias: column.alias.clone(),
303                })
304                .collect(),
305        };
306    }
307
308    if body.distinct {
309        current = PhysicalPlan::Distinct {
310            input: Box::new(current),
311        };
312    }
313
314    if let Some(limit) = body.limit
315        && !uses_vector_search
316    {
317        current = PhysicalPlan::Limit {
318            input: Box::new(current),
319            count: limit,
320        };
321    }
322
323    Ok(current)
324}
325
326fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
327    match plan {
328        PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
329            Ok(Some(table.clone()))
330        }
331        PhysicalPlan::Filter { input, .. }
332        | PhysicalPlan::Project { input, .. }
333        | PhysicalPlan::Distinct { input }
334        | PhysicalPlan::Sort { input, .. }
335        | PhysicalPlan::Limit { input, .. }
336        | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
337        PhysicalPlan::Join { left, right, .. } => {
338            let left_table = vector_base_table(left)?;
339            let right_table = vector_base_table(right)?;
340            match (left_table, right_table) {
341                (Some(left), Some(right)) if left == right => Ok(Some(left)),
342                (Some(_), Some(_)) => Err(Error::PlanError(
343                    "ambiguous physical vector source table in join".to_string(),
344                )),
345                (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
346                (None, None) => Ok(None),
347            }
348        }
349        PhysicalPlan::Pipeline(plans) => {
350            for plan in plans.iter().rev() {
351                if let Some(table) = vector_base_table(plan)? {
352                    return Ok(Some(table));
353                }
354            }
355            Ok(None)
356        }
357        PhysicalPlan::GraphBfs { .. }
358        | PhysicalPlan::CteRef { .. }
359        | PhysicalPlan::Union { .. } => Ok(None),
360        _ => Ok(None),
361    }
362}
363
364fn graph_plan_from_from_item(
365    from_item: &FromItem,
366    cte_env: &HashMap<String, PhysicalPlan>,
367) -> Result<PhysicalPlan> {
368    match from_item {
369        FromItem::GraphTable {
370            match_clause,
371            columns,
372            ..
373        } => {
374            let bfs = graph_bfs_from_match(match_clause, cte_env)?;
375            if columns.is_empty() {
376                Ok(bfs)
377            } else {
378                Ok(PhysicalPlan::Project {
379                    input: Box::new(bfs),
380                    columns: columns
381                        .iter()
382                        .map(|c| ProjectColumn {
383                            expr: c.expr.clone(),
384                            alias: Some(c.alias.clone()),
385                        })
386                        .collect(),
387                })
388            }
389        }
390        FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
391            table: name.clone(),
392            alias: None,
393            filter: None,
394        }),
395    }
396}
397
398fn graph_bfs_from_match(
399    match_clause: &contextdb_parser::ast::MatchClause,
400    cte_env: &HashMap<String, PhysicalPlan>,
401) -> Result<PhysicalPlan> {
402    let steps = match_clause
403        .pattern
404        .edges
405        .iter()
406        .map(|step| {
407            let max_depth = if step.max_hops == 0 {
408                DEFAULT_MATCH_DEPTH
409            } else {
410                step.max_hops
411            };
412            if max_depth > ENGINE_MAX_BFS_DEPTH {
413                return Err(Error::BfsDepthExceeded(max_depth));
414            }
415
416            Ok(GraphStepPlan {
417                edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
418                direction: match step.direction {
419                    contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
420                    contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
421                    contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
422                },
423                min_depth: step.min_hops.max(1),
424                max_depth,
425                target_alias: step.target.alias.clone(),
426            })
427        })
428        .collect::<Result<Vec<_>>>()?;
429    if steps.is_empty() {
430        return Err(Error::PlanError(
431            "MATCH must include at least one edge".into(),
432        ));
433    }
434
435    Ok(PhysicalPlan::GraphBfs {
436        start_alias: match_clause.pattern.start.alias.clone(),
437        start_expr: extract_graph_start_expr(match_clause)?,
438        start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
439        steps,
440        filter: match_clause.where_clause.clone(),
441    })
442}
443
444fn extract_graph_start_candidates(
445    match_clause: &MatchClause,
446    cte_env: &HashMap<String, PhysicalPlan>,
447) -> Result<Option<Box<PhysicalPlan>>> {
448    let Some(where_clause) = &match_clause.where_clause else {
449        return Ok(None);
450    };
451    find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
452}
453
454fn find_graph_start_candidates(
455    expr: &Expr,
456    start_alias: &str,
457    cte_env: &HashMap<String, PhysicalPlan>,
458) -> Result<Option<Box<PhysicalPlan>>> {
459    match expr {
460        Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
461            Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
462        }
463        Expr::BinaryOp { left, right, .. } => {
464            if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
465                return Ok(Some(plan));
466            }
467            find_graph_start_candidates(right, start_alias, cte_env)
468        }
469        Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
470        _ => Ok(None),
471    }
472}
473
474fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
475    let start_alias = &match_clause.pattern.start.alias;
476    if let Some(where_clause) = &match_clause.where_clause
477        && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
478    {
479        return Ok(expr);
480    }
481
482    Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
483        table: None,
484        column: start_alias.clone(),
485    }))
486}
487
488fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
489    match expr {
490        Expr::BinaryOp {
491            left,
492            op: BinOp::Eq,
493            right,
494        } => {
495            if is_graph_start_id_ref(left, start_alias) {
496                Some((**right).clone())
497            } else if is_graph_start_id_ref(right, start_alias) {
498                Some((**left).clone())
499            } else {
500                None
501            }
502        }
503        Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
504            .or_else(|| find_graph_start_expr(right, start_alias)),
505        Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
506        _ => None,
507    }
508}
509
510fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
511    matches!(
512        expr,
513        Expr::Column(contextdb_parser::ast::ColumnRef {
514            table: Some(table),
515            column
516        }) if table == start_alias && column == "id"
517    )
518}
519
520/// Stub bridge from parser-AST `SortDirection` (Asc/Desc/CosineDistance) to
521/// core engine `SortDirection` (Asc/Desc). CosineDistance is not meaningful
522/// for B-tree indexes; reject it as a parse-level error.
523fn map_parser_to_core_sort_direction(
524    dir: contextdb_parser::ast::SortDirection,
525) -> Result<contextdb_core::SortDirection> {
526    match dir {
527        contextdb_parser::ast::SortDirection::Asc => Ok(contextdb_core::SortDirection::Asc),
528        contextdb_parser::ast::SortDirection::Desc => Ok(contextdb_core::SortDirection::Desc),
529        contextdb_parser::ast::SortDirection::CosineDistance => Err(Error::ParseError(
530            "CosineDistance is not a valid CREATE INDEX direction".to_string(),
531        )),
532    }
533}
534
535/// Stub: returns None so the planner always uses Scan. Impl must inspect the
536/// WHERE clause, match eligible predicate shapes, consult TableMeta.indexes,
537/// and return Some(IndexScan { ... }) when applicable.
538#[allow(dead_code)]
539fn try_plan_index_scan(
540    _table: &str,
541    _where_clause: Option<&Expr>,
542    _indexes: &[contextdb_core::table_meta::IndexDecl],
543) -> Option<crate::plan::PhysicalPlan> {
544    None
545}