// contextdb_planner/planner.rs

1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4    AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5    SortDirection, Statement,
6};
7use std::collections::HashMap;
8
/// Hop upper bound used for a MATCH edge step whose `max_hops` is 0.
const DEFAULT_MATCH_DEPTH: u32 = 5;
/// Largest per-step BFS depth the planner accepts; larger requests are
/// rejected with `Error::BfsDepthExceeded`.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;
/// `max_depth` assigned to propagation rules that do not declare one.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;

// The defaulted MATCH depth must itself be accepted by the depth check.
const _: () = assert!(DEFAULT_MATCH_DEPTH <= ENGINE_MAX_BFS_DEPTH);
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14    match stmt {
15        Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16            name: ct.name.clone(),
17            columns: ct.columns.clone(),
18            unique_constraints: ct.unique_constraints.clone(),
19            immutable: ct.immutable,
20            state_machine: ct.state_machine.clone(),
21            dag_edge_types: ct.dag_edge_types.clone(),
22            propagation_rules: extract_propagation_rules(ct)?,
23            retain: ct.retain.clone(),
24        })),
25        Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
26            table: at.table.clone(),
27            action: at.action.clone(),
28        })),
29        Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
30        Statement::CreateIndex(ci) => {
31            let mut columns = Vec::with_capacity(ci.columns.len());
32            for (col, dir) in &ci.columns {
33                columns.push((col.clone(), map_parser_to_core_sort_direction(*dir)?));
34            }
35            Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
36                name: ci.name.clone(),
37                table: ci.table.clone(),
38                columns,
39            }))
40        }
41        Statement::DropIndex(di) => Ok(PhysicalPlan::DropIndex(DropIndexPlan {
42            name: di.name.clone(),
43            table: di.table.clone(),
44            if_exists: di.if_exists,
45        })),
46        Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
47            table: i.table.clone(),
48            columns: i.columns.clone(),
49            values: i.values.clone(),
50            on_conflict: i.on_conflict.clone().map(Into::into),
51        })),
52        Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
53            table: d.table.clone(),
54            where_clause: d.where_clause.clone(),
55        })),
56        Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
57            table: u.table.clone(),
58            assignments: u.assignments.clone(),
59            where_clause: u.where_clause.clone(),
60        })),
61        Statement::Select(sel) => plan_select(sel),
62        Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
63        Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
64        Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
65        Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
66        Statement::SetSyncConflictPolicy(policy) => {
67            Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
68        }
69        Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
70        Statement::ShowVectorIndexes => Ok(PhysicalPlan::ShowVectorIndexes),
71        Statement::Begin | Statement::Commit | Statement::Rollback => {
72            Ok(PhysicalPlan::Pipeline(vec![]))
73        }
74    }
75}
76
77fn extract_propagation_rules(
78    ct: &contextdb_parser::ast::CreateTable,
79) -> Result<Vec<PropagationRule>> {
80    let mut rules = Vec::new();
81
82    for column in &ct.columns {
83        if let Some(fk) = &column.references {
84            for rule in &fk.propagation_rules {
85                if let AstPropagationRule::FkState {
86                    trigger_state,
87                    target_state,
88                    max_depth,
89                    abort_on_failure,
90                } = rule
91                {
92                    rules.push(PropagationRule::ForeignKey {
93                        fk_column: column.name.clone(),
94                        referenced_table: fk.table.clone(),
95                        referenced_column: fk.column.clone(),
96                        trigger_state: trigger_state.clone(),
97                        target_state: target_state.clone(),
98                        max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
99                        abort_on_failure: *abort_on_failure,
100                    });
101                }
102            }
103        }
104    }
105
106    for rule in &ct.propagation_rules {
107        match rule {
108            AstPropagationRule::EdgeState {
109                edge_type,
110                direction,
111                trigger_state,
112                target_state,
113                max_depth,
114                abort_on_failure,
115            } => {
116                let direction = match direction.to_ascii_uppercase().as_str() {
117                    "OUTGOING" => Direction::Outgoing,
118                    "INCOMING" => Direction::Incoming,
119                    "BOTH" => Direction::Both,
120                    other => {
121                        return Err(Error::PlanError(format!(
122                            "invalid edge direction in propagation rule: {}",
123                            other
124                        )));
125                    }
126                };
127                rules.push(PropagationRule::Edge {
128                    edge_type: edge_type.clone(),
129                    direction,
130                    trigger_state: trigger_state.clone(),
131                    target_state: target_state.clone(),
132                    max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
133                    abort_on_failure: *abort_on_failure,
134                });
135            }
136            AstPropagationRule::VectorExclusion { trigger_state } => {
137                rules.push(PropagationRule::VectorExclusion {
138                    trigger_state: trigger_state.clone(),
139                });
140            }
141            AstPropagationRule::FkState { .. } => {}
142        }
143    }
144
145    Ok(rules)
146}
147
148fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
149    let mut cte_env = HashMap::new();
150
151    for cte in &sel.ctes {
152        match cte {
153            Cte::MatchCte { name, match_clause } => {
154                cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
155            }
156            Cte::SqlCte { name, query } => {
157                cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
158            }
159        }
160    }
161
162    plan_select_body(&sel.body, &cte_env)
163}
164
165fn plan_select_body(
166    body: &SelectBody,
167    cte_env: &HashMap<String, PhysicalPlan>,
168) -> Result<PhysicalPlan> {
169    let graph_from = body
170        .from
171        .iter()
172        .find(|f| matches!(f, FromItem::GraphTable { .. }));
173
174    let mut current = if let Some(from_item) = graph_from {
175        graph_plan_from_from_item(from_item, cte_env)?
176    } else {
177        let from_item = body.from.iter().find_map(|item| match item {
178            FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
179            FromItem::GraphTable { .. } => None,
180        });
181
182        match from_item {
183            Some((from_table, from_alias)) => {
184                if let Some(cte_plan) = cte_env.get(&from_table) {
185                    let mut cte_plan = cte_plan.clone();
186                    if body.joins.is_empty()
187                        && let Some(where_clause) = &body.where_clause
188                    {
189                        cte_plan = PhysicalPlan::Filter {
190                            input: Box::new(cte_plan),
191                            predicate: where_clause.clone(),
192                        };
193                    }
194                    cte_plan
195                } else {
196                    PhysicalPlan::Scan {
197                        table: from_table,
198                        alias: from_alias.clone(),
199                        filter: if body.joins.is_empty() {
200                            body.where_clause.clone()
201                        } else {
202                            None
203                        },
204                    }
205                }
206            }
207            None => PhysicalPlan::Scan {
208                table: "dual".to_string(),
209                alias: None,
210                filter: None,
211            },
212        }
213    };
214
215    if !body.joins.is_empty() {
216        let left_alias = body.from.iter().find_map(|item| match item {
217            FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
218            FromItem::GraphTable { .. } => None,
219        });
220
221        for join in &body.joins {
222            let right = if let Some(cte_plan) = cte_env.get(&join.table) {
223                cte_plan.clone()
224            } else {
225                PhysicalPlan::Scan {
226                    table: join.table.clone(),
227                    alias: join.alias.clone(),
228                    filter: None,
229                }
230            };
231
232            current = PhysicalPlan::Join {
233                left: Box::new(current),
234                right: Box::new(right),
235                condition: join.on.clone(),
236                join_type: match join.join_type {
237                    contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
238                    contextdb_parser::ast::JoinType::Left => JoinType::Left,
239                },
240                left_alias: left_alias.clone(),
241                right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
242            };
243        }
244
245        if let Some(where_clause) = &body.where_clause {
246            current = PhysicalPlan::Filter {
247                input: Box::new(current),
248                predicate: where_clause.clone(),
249            };
250        }
251    }
252
253    let uses_vector_search = body
254        .order_by
255        .first()
256        .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
257
258    if body.use_rank.is_some() && !uses_vector_search {
259        return Err(Error::UseRankRequiresVectorOrder);
260    }
261
262    if let Some(order) = body.order_by.first()
263        && matches!(order.direction, SortDirection::CosineDistance)
264    {
265        let k = body.limit.ok_or_else(|| {
266            if body.use_rank.is_some() {
267                Error::UseRankRequiresLimit
268            } else {
269                Error::UnboundedVectorSearch
270            }
271        })?;
272        let vector_table = vector_base_table(&current)?.ok_or_else(|| {
273            Error::PlanError("unable to resolve physical vector source table".to_string())
274        })?;
275        let (column, query_expr) = vector_search_parts(&order.expr)?;
276        current = PhysicalPlan::VectorSearch {
277            table: vector_table,
278            column,
279            query_expr,
280            k,
281            candidates: Some(Box::new(current)),
282            sort_key: body.use_rank.clone(),
283        };
284    }
285
286    if !body.order_by.is_empty() && !uses_vector_search {
287        current = PhysicalPlan::Sort {
288            input: Box::new(current),
289            keys: body
290                .order_by
291                .iter()
292                .map(|item| SortKey {
293                    expr: item.expr.clone(),
294                    direction: item.direction,
295                })
296                .collect(),
297        };
298    }
299
300    let is_select_star = matches!(
301        body.columns.as_slice(),
302        [contextdb_parser::ast::SelectColumn {
303            expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
304            alias: None
305        }] if column == "*"
306    );
307    if !is_select_star {
308        current = PhysicalPlan::Project {
309            input: Box::new(current),
310            columns: body
311                .columns
312                .iter()
313                .map(|column| ProjectColumn {
314                    expr: column.expr.clone(),
315                    alias: column.alias.clone(),
316                })
317                .collect(),
318        };
319    }
320
321    if body.distinct {
322        current = PhysicalPlan::Distinct {
323            input: Box::new(current),
324        };
325    }
326
327    if let Some(limit) = body.limit
328        && !uses_vector_search
329    {
330        current = PhysicalPlan::Limit {
331            input: Box::new(current),
332            count: limit,
333        };
334    }
335
336    Ok(current)
337}
338
339fn vector_search_parts(expr: &Expr) -> Result<(String, Expr)> {
340    match expr {
341        Expr::CosineDistance { left, right } => match left.as_ref() {
342            Expr::Column(column) => Ok((column.column.clone(), right.as_ref().clone())),
343            _ => Err(Error::PlanError(
344                "left side of vector distance must be a column".to_string(),
345            )),
346        },
347        _ => Err(Error::PlanError(
348            "vector search requires a cosine distance expression".to_string(),
349        )),
350    }
351}
352
353fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
354    match plan {
355        PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
356            Ok(Some(table.clone()))
357        }
358        PhysicalPlan::Filter { input, .. }
359        | PhysicalPlan::Project { input, .. }
360        | PhysicalPlan::Distinct { input }
361        | PhysicalPlan::Sort { input, .. }
362        | PhysicalPlan::Limit { input, .. }
363        | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
364        PhysicalPlan::Join { left, right, .. } => {
365            let left_table = vector_base_table(left)?;
366            let right_table = vector_base_table(right)?;
367            match (left_table, right_table) {
368                (Some(left), Some(right)) if left == right => Ok(Some(left)),
369                (Some(_), Some(_)) => Err(Error::PlanError(
370                    "ambiguous physical vector source table in join".to_string(),
371                )),
372                (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
373                (None, None) => Ok(None),
374            }
375        }
376        PhysicalPlan::Pipeline(plans) => {
377            for plan in plans.iter().rev() {
378                if let Some(table) = vector_base_table(plan)? {
379                    return Ok(Some(table));
380                }
381            }
382            Ok(None)
383        }
384        PhysicalPlan::GraphBfs { .. }
385        | PhysicalPlan::CteRef { .. }
386        | PhysicalPlan::Union { .. } => Ok(None),
387        _ => Ok(None),
388    }
389}
390
391fn graph_plan_from_from_item(
392    from_item: &FromItem,
393    cte_env: &HashMap<String, PhysicalPlan>,
394) -> Result<PhysicalPlan> {
395    match from_item {
396        FromItem::GraphTable {
397            match_clause,
398            columns,
399            ..
400        } => {
401            let bfs = graph_bfs_from_match(match_clause, cte_env)?;
402            if columns.is_empty() {
403                Ok(bfs)
404            } else {
405                Ok(PhysicalPlan::Project {
406                    input: Box::new(bfs),
407                    columns: columns
408                        .iter()
409                        .map(|c| ProjectColumn {
410                            expr: c.expr.clone(),
411                            alias: Some(c.alias.clone()),
412                        })
413                        .collect(),
414                })
415            }
416        }
417        FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
418            table: name.clone(),
419            alias: None,
420            filter: None,
421        }),
422    }
423}
424
425fn graph_bfs_from_match(
426    match_clause: &contextdb_parser::ast::MatchClause,
427    cte_env: &HashMap<String, PhysicalPlan>,
428) -> Result<PhysicalPlan> {
429    let steps = match_clause
430        .pattern
431        .edges
432        .iter()
433        .map(|step| {
434            let max_depth = if step.max_hops == 0 {
435                DEFAULT_MATCH_DEPTH
436            } else {
437                step.max_hops
438            };
439            if max_depth > ENGINE_MAX_BFS_DEPTH {
440                return Err(Error::BfsDepthExceeded(max_depth));
441            }
442
443            Ok(GraphStepPlan {
444                edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
445                direction: match step.direction {
446                    contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
447                    contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
448                    contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
449                },
450                min_depth: step.min_hops.max(1),
451                max_depth,
452                target_alias: step.target.alias.clone(),
453            })
454        })
455        .collect::<Result<Vec<_>>>()?;
456    if steps.is_empty() {
457        return Err(Error::PlanError(
458            "MATCH must include at least one edge".into(),
459        ));
460    }
461
462    Ok(PhysicalPlan::GraphBfs {
463        start_alias: match_clause.pattern.start.alias.clone(),
464        start_expr: extract_graph_start_expr(match_clause)?,
465        start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
466        steps,
467        filter: match_clause.where_clause.clone(),
468    })
469}
470
471fn extract_graph_start_candidates(
472    match_clause: &MatchClause,
473    cte_env: &HashMap<String, PhysicalPlan>,
474) -> Result<Option<Box<PhysicalPlan>>> {
475    let Some(where_clause) = &match_clause.where_clause else {
476        return Ok(None);
477    };
478    find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
479}
480
481fn find_graph_start_candidates(
482    expr: &Expr,
483    start_alias: &str,
484    cte_env: &HashMap<String, PhysicalPlan>,
485) -> Result<Option<Box<PhysicalPlan>>> {
486    match expr {
487        Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
488            Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
489        }
490        Expr::BinaryOp { left, right, .. } => {
491            if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
492                return Ok(Some(plan));
493            }
494            find_graph_start_candidates(right, start_alias, cte_env)
495        }
496        Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
497        _ => Ok(None),
498    }
499}
500
501fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
502    let start_alias = &match_clause.pattern.start.alias;
503    if let Some(where_clause) = &match_clause.where_clause
504        && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
505    {
506        return Ok(expr);
507    }
508
509    Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
510        table: None,
511        column: start_alias.clone(),
512    }))
513}
514
515fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
516    match expr {
517        Expr::BinaryOp {
518            left,
519            op: BinOp::Eq,
520            right,
521        } => {
522            if is_graph_start_id_ref(left, start_alias) {
523                Some((**right).clone())
524            } else if is_graph_start_id_ref(right, start_alias) {
525                Some((**left).clone())
526            } else {
527                None
528            }
529        }
530        Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
531            .or_else(|| find_graph_start_expr(right, start_alias)),
532        Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
533        _ => None,
534    }
535}
536
537fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
538    matches!(
539        expr,
540        Expr::Column(contextdb_parser::ast::ColumnRef {
541            table: Some(table),
542            column
543        }) if table == start_alias && column == "id"
544    )
545}
546
547/// Stub bridge from parser-AST `SortDirection` (Asc/Desc/CosineDistance) to
548/// core engine `SortDirection` (Asc/Desc). CosineDistance is not meaningful
549/// for B-tree indexes; reject it as a parse-level error.
550fn map_parser_to_core_sort_direction(
551    dir: contextdb_parser::ast::SortDirection,
552) -> Result<contextdb_core::SortDirection> {
553    match dir {
554        contextdb_parser::ast::SortDirection::Asc => Ok(contextdb_core::SortDirection::Asc),
555        contextdb_parser::ast::SortDirection::Desc => Ok(contextdb_core::SortDirection::Desc),
556        contextdb_parser::ast::SortDirection::CosineDistance => Err(Error::ParseError(
557            "CosineDistance is not a valid CREATE INDEX direction".to_string(),
558        )),
559    }
560}
561
562/// Stub: returns None so the planner always uses Scan. Impl must inspect the
563/// WHERE clause, match eligible predicate shapes, consult TableMeta.indexes,
564/// and return Some(IndexScan { ... }) when applicable.
565#[allow(dead_code)]
566fn try_plan_index_scan(
567    _table: &str,
568    _where_clause: Option<&Expr>,
569    _indexes: &[contextdb_core::table_meta::IndexDecl],
570) -> Option<crate::plan::PhysicalPlan> {
571    None
572}