Skip to main content

contextdb_planner/
planner.rs

1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4    AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5    SortDirection, Statement,
6};
7use std::collections::HashMap;
8
/// Upper bound on the max depth of any MATCH edge step; planning fails with
/// `Error::BfsDepthExceeded` when a step asks for more than this.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;
/// Max depth used for a MATCH edge step whose max hop count is left unset (`0`).
const DEFAULT_MATCH_DEPTH: u32 = 5;
/// Max depth given to propagation rules that do not declare one explicitly.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14    match stmt {
15        Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16            name: ct.name.clone(),
17            columns: ct.columns.clone(),
18            immutable: ct.immutable,
19            state_machine: ct.state_machine.clone(),
20            dag_edge_types: ct.dag_edge_types.clone(),
21            propagation_rules: extract_propagation_rules(ct)?,
22            retain: ct.retain.clone(),
23        })),
24        Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
25            table: at.table.clone(),
26            action: at.action.clone(),
27        })),
28        Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
29        Statement::CreateIndex(ci) => Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
30            name: ci.name.clone(),
31            table: ci.table.clone(),
32            columns: ci.columns.clone(),
33        })),
34        Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
35            table: i.table.clone(),
36            columns: i.columns.clone(),
37            values: i.values.clone(),
38            on_conflict: i.on_conflict.clone().map(Into::into),
39        })),
40        Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
41            table: d.table.clone(),
42            where_clause: d.where_clause.clone(),
43        })),
44        Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
45            table: u.table.clone(),
46            assignments: u.assignments.clone(),
47            where_clause: u.where_clause.clone(),
48        })),
49        Statement::Select(sel) => plan_select(sel),
50        Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
51        Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
52        Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
53        Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
54        Statement::SetSyncConflictPolicy(policy) => {
55            Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
56        }
57        Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
58        Statement::Begin | Statement::Commit | Statement::Rollback => {
59            Ok(PhysicalPlan::Pipeline(vec![]))
60        }
61    }
62}
63
64fn extract_propagation_rules(
65    ct: &contextdb_parser::ast::CreateTable,
66) -> Result<Vec<PropagationRule>> {
67    let mut rules = Vec::new();
68
69    for column in &ct.columns {
70        if let Some(fk) = &column.references {
71            for rule in &fk.propagation_rules {
72                if let AstPropagationRule::FkState {
73                    trigger_state,
74                    target_state,
75                    max_depth,
76                    abort_on_failure,
77                } = rule
78                {
79                    rules.push(PropagationRule::ForeignKey {
80                        fk_column: column.name.clone(),
81                        referenced_table: fk.table.clone(),
82                        referenced_column: fk.column.clone(),
83                        trigger_state: trigger_state.clone(),
84                        target_state: target_state.clone(),
85                        max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
86                        abort_on_failure: *abort_on_failure,
87                    });
88                }
89            }
90        }
91    }
92
93    for rule in &ct.propagation_rules {
94        match rule {
95            AstPropagationRule::EdgeState {
96                edge_type,
97                direction,
98                trigger_state,
99                target_state,
100                max_depth,
101                abort_on_failure,
102            } => {
103                let direction = match direction.to_ascii_uppercase().as_str() {
104                    "OUTGOING" => Direction::Outgoing,
105                    "INCOMING" => Direction::Incoming,
106                    "BOTH" => Direction::Both,
107                    other => {
108                        return Err(Error::PlanError(format!(
109                            "invalid edge direction in propagation rule: {}",
110                            other
111                        )));
112                    }
113                };
114                rules.push(PropagationRule::Edge {
115                    edge_type: edge_type.clone(),
116                    direction,
117                    trigger_state: trigger_state.clone(),
118                    target_state: target_state.clone(),
119                    max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
120                    abort_on_failure: *abort_on_failure,
121                });
122            }
123            AstPropagationRule::VectorExclusion { trigger_state } => {
124                rules.push(PropagationRule::VectorExclusion {
125                    trigger_state: trigger_state.clone(),
126                });
127            }
128            AstPropagationRule::FkState { .. } => {}
129        }
130    }
131
132    Ok(rules)
133}
134
135fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
136    let mut cte_env = HashMap::new();
137
138    for cte in &sel.ctes {
139        match cte {
140            Cte::MatchCte { name, match_clause } => {
141                cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
142            }
143            Cte::SqlCte { name, query } => {
144                cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
145            }
146        }
147    }
148
149    plan_select_body(&sel.body, &cte_env)
150}
151
152fn plan_select_body(
153    body: &SelectBody,
154    cte_env: &HashMap<String, PhysicalPlan>,
155) -> Result<PhysicalPlan> {
156    let graph_from = body
157        .from
158        .iter()
159        .find(|f| matches!(f, FromItem::GraphTable { .. }));
160
161    let mut current = if let Some(from_item) = graph_from {
162        graph_plan_from_from_item(from_item, cte_env)?
163    } else {
164        let from_item = body.from.iter().find_map(|item| match item {
165            FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
166            FromItem::GraphTable { .. } => None,
167        });
168
169        match from_item {
170            Some((from_table, from_alias)) => {
171                if let Some(cte_plan) = cte_env.get(&from_table) {
172                    let mut cte_plan = cte_plan.clone();
173                    if body.joins.is_empty()
174                        && let Some(where_clause) = &body.where_clause
175                    {
176                        cte_plan = PhysicalPlan::Filter {
177                            input: Box::new(cte_plan),
178                            predicate: where_clause.clone(),
179                        };
180                    }
181                    cte_plan
182                } else {
183                    PhysicalPlan::Scan {
184                        table: from_table,
185                        alias: from_alias.clone(),
186                        filter: if body.joins.is_empty() {
187                            body.where_clause.clone()
188                        } else {
189                            None
190                        },
191                    }
192                }
193            }
194            None => PhysicalPlan::Scan {
195                table: "dual".to_string(),
196                alias: None,
197                filter: None,
198            },
199        }
200    };
201
202    if !body.joins.is_empty() {
203        let left_alias = body.from.iter().find_map(|item| match item {
204            FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
205            FromItem::GraphTable { .. } => None,
206        });
207
208        for join in &body.joins {
209            let right = if let Some(cte_plan) = cte_env.get(&join.table) {
210                cte_plan.clone()
211            } else {
212                PhysicalPlan::Scan {
213                    table: join.table.clone(),
214                    alias: join.alias.clone(),
215                    filter: None,
216                }
217            };
218
219            current = PhysicalPlan::Join {
220                left: Box::new(current),
221                right: Box::new(right),
222                condition: join.on.clone(),
223                join_type: match join.join_type {
224                    contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
225                    contextdb_parser::ast::JoinType::Left => JoinType::Left,
226                },
227                left_alias: left_alias.clone(),
228                right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
229            };
230        }
231
232        if let Some(where_clause) = &body.where_clause {
233            current = PhysicalPlan::Filter {
234                input: Box::new(current),
235                predicate: where_clause.clone(),
236            };
237        }
238    }
239
240    let uses_vector_search = body
241        .order_by
242        .first()
243        .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
244
245    if let Some(order) = body.order_by.first()
246        && matches!(order.direction, SortDirection::CosineDistance)
247    {
248        let k = body.limit.ok_or(Error::UnboundedVectorSearch)?;
249        let vector_table = vector_base_table(&current)?.ok_or_else(|| {
250            Error::PlanError("unable to resolve physical vector source table".to_string())
251        })?;
252        current = PhysicalPlan::VectorSearch {
253            table: vector_table,
254            column: "embedding".to_string(),
255            query_expr: order.expr.clone(),
256            k,
257            candidates: Some(Box::new(current)),
258        };
259    }
260
261    if !body.order_by.is_empty() && !uses_vector_search {
262        current = PhysicalPlan::Sort {
263            input: Box::new(current),
264            keys: body
265                .order_by
266                .iter()
267                .map(|item| SortKey {
268                    expr: item.expr.clone(),
269                    direction: item.direction,
270                })
271                .collect(),
272        };
273    }
274
275    let is_select_star = matches!(
276        body.columns.as_slice(),
277        [contextdb_parser::ast::SelectColumn {
278            expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
279            alias: None
280        }] if column == "*"
281    );
282    if !is_select_star {
283        current = PhysicalPlan::Project {
284            input: Box::new(current),
285            columns: body
286                .columns
287                .iter()
288                .map(|column| ProjectColumn {
289                    expr: column.expr.clone(),
290                    alias: column.alias.clone(),
291                })
292                .collect(),
293        };
294    }
295
296    if body.distinct {
297        current = PhysicalPlan::Distinct {
298            input: Box::new(current),
299        };
300    }
301
302    if let Some(limit) = body.limit
303        && !uses_vector_search
304    {
305        current = PhysicalPlan::Limit {
306            input: Box::new(current),
307            count: limit,
308        };
309    }
310
311    Ok(current)
312}
313
314fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
315    match plan {
316        PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
317            Ok(Some(table.clone()))
318        }
319        PhysicalPlan::Filter { input, .. }
320        | PhysicalPlan::Project { input, .. }
321        | PhysicalPlan::Distinct { input }
322        | PhysicalPlan::Sort { input, .. }
323        | PhysicalPlan::Limit { input, .. }
324        | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
325        PhysicalPlan::Join { left, right, .. } => {
326            let left_table = vector_base_table(left)?;
327            let right_table = vector_base_table(right)?;
328            match (left_table, right_table) {
329                (Some(left), Some(right)) if left == right => Ok(Some(left)),
330                (Some(_), Some(_)) => Err(Error::PlanError(
331                    "ambiguous physical vector source table in join".to_string(),
332                )),
333                (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
334                (None, None) => Ok(None),
335            }
336        }
337        PhysicalPlan::Pipeline(plans) => {
338            for plan in plans.iter().rev() {
339                if let Some(table) = vector_base_table(plan)? {
340                    return Ok(Some(table));
341                }
342            }
343            Ok(None)
344        }
345        PhysicalPlan::GraphBfs { .. }
346        | PhysicalPlan::CteRef { .. }
347        | PhysicalPlan::Union { .. } => Ok(None),
348        _ => Ok(None),
349    }
350}
351
352fn graph_plan_from_from_item(
353    from_item: &FromItem,
354    cte_env: &HashMap<String, PhysicalPlan>,
355) -> Result<PhysicalPlan> {
356    match from_item {
357        FromItem::GraphTable {
358            match_clause,
359            columns,
360            ..
361        } => {
362            let bfs = graph_bfs_from_match(match_clause, cte_env)?;
363            if columns.is_empty() {
364                Ok(bfs)
365            } else {
366                Ok(PhysicalPlan::Project {
367                    input: Box::new(bfs),
368                    columns: columns
369                        .iter()
370                        .map(|c| ProjectColumn {
371                            expr: c.expr.clone(),
372                            alias: Some(c.alias.clone()),
373                        })
374                        .collect(),
375                })
376            }
377        }
378        FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
379            table: name.clone(),
380            alias: None,
381            filter: None,
382        }),
383    }
384}
385
386fn graph_bfs_from_match(
387    match_clause: &contextdb_parser::ast::MatchClause,
388    cte_env: &HashMap<String, PhysicalPlan>,
389) -> Result<PhysicalPlan> {
390    let steps = match_clause
391        .pattern
392        .edges
393        .iter()
394        .map(|step| {
395            let max_depth = if step.max_hops == 0 {
396                DEFAULT_MATCH_DEPTH
397            } else {
398                step.max_hops
399            };
400            if max_depth > ENGINE_MAX_BFS_DEPTH {
401                return Err(Error::BfsDepthExceeded(max_depth));
402            }
403
404            Ok(GraphStepPlan {
405                edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
406                direction: match step.direction {
407                    contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
408                    contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
409                    contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
410                },
411                min_depth: step.min_hops.max(1),
412                max_depth,
413                target_alias: step.target.alias.clone(),
414            })
415        })
416        .collect::<Result<Vec<_>>>()?;
417    if steps.is_empty() {
418        return Err(Error::PlanError(
419            "MATCH must include at least one edge".into(),
420        ));
421    }
422
423    Ok(PhysicalPlan::GraphBfs {
424        start_alias: match_clause.pattern.start.alias.clone(),
425        start_expr: extract_graph_start_expr(match_clause)?,
426        start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
427        steps,
428        filter: match_clause.where_clause.clone(),
429    })
430}
431
432fn extract_graph_start_candidates(
433    match_clause: &MatchClause,
434    cte_env: &HashMap<String, PhysicalPlan>,
435) -> Result<Option<Box<PhysicalPlan>>> {
436    let Some(where_clause) = &match_clause.where_clause else {
437        return Ok(None);
438    };
439    find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
440}
441
442fn find_graph_start_candidates(
443    expr: &Expr,
444    start_alias: &str,
445    cte_env: &HashMap<String, PhysicalPlan>,
446) -> Result<Option<Box<PhysicalPlan>>> {
447    match expr {
448        Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
449            Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
450        }
451        Expr::BinaryOp { left, right, .. } => {
452            if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
453                return Ok(Some(plan));
454            }
455            find_graph_start_candidates(right, start_alias, cte_env)
456        }
457        Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
458        _ => Ok(None),
459    }
460}
461
462fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
463    let start_alias = &match_clause.pattern.start.alias;
464    if let Some(where_clause) = &match_clause.where_clause
465        && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
466    {
467        return Ok(expr);
468    }
469
470    Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
471        table: None,
472        column: start_alias.clone(),
473    }))
474}
475
476fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
477    match expr {
478        Expr::BinaryOp {
479            left,
480            op: BinOp::Eq,
481            right,
482        } => {
483            if is_graph_start_id_ref(left, start_alias) {
484                Some((**right).clone())
485            } else if is_graph_start_id_ref(right, start_alias) {
486                Some((**left).clone())
487            } else {
488                None
489            }
490        }
491        Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
492            .or_else(|| find_graph_start_expr(right, start_alias)),
493        Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
494        _ => None,
495    }
496}
497
498fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
499    matches!(
500        expr,
501        Expr::Column(contextdb_parser::ast::ColumnRef {
502            table: Some(table),
503            column
504        }) if table == start_alias && column == "id"
505    )
506}