1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4 AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5 SortDirection, Statement,
6};
7use std::collections::HashMap;
8
// Traversal depth used for a MATCH edge step whose `max_hops` is 0.
const DEFAULT_MATCH_DEPTH: u32 = 5;
// Upper bound on a MATCH step's depth; planning fails with
// `Error::BfsDepthExceeded` when a step asks for more.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;
// `max_depth` applied to foreign-key and edge propagation rules that do not
// specify one.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14 match stmt {
15 Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16 name: ct.name.clone(),
17 columns: ct.columns.clone(),
18 unique_constraints: ct.unique_constraints.clone(),
19 immutable: ct.immutable,
20 state_machine: ct.state_machine.clone(),
21 dag_edge_types: ct.dag_edge_types.clone(),
22 propagation_rules: extract_propagation_rules(ct)?,
23 retain: ct.retain.clone(),
24 })),
25 Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
26 table: at.table.clone(),
27 action: at.action.clone(),
28 })),
29 Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
30 Statement::CreateIndex(ci) => {
31 let mut columns = Vec::with_capacity(ci.columns.len());
32 for (col, dir) in &ci.columns {
33 columns.push((col.clone(), map_parser_to_core_sort_direction(*dir)?));
34 }
35 Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
36 name: ci.name.clone(),
37 table: ci.table.clone(),
38 columns,
39 }))
40 }
41 Statement::DropIndex(di) => Ok(PhysicalPlan::DropIndex(DropIndexPlan {
42 name: di.name.clone(),
43 table: di.table.clone(),
44 if_exists: di.if_exists,
45 })),
46 Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
47 table: i.table.clone(),
48 columns: i.columns.clone(),
49 values: i.values.clone(),
50 on_conflict: i.on_conflict.clone().map(Into::into),
51 })),
52 Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
53 table: d.table.clone(),
54 where_clause: d.where_clause.clone(),
55 })),
56 Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
57 table: u.table.clone(),
58 assignments: u.assignments.clone(),
59 where_clause: u.where_clause.clone(),
60 })),
61 Statement::Select(sel) => plan_select(sel),
62 Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
63 Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
64 Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
65 Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
66 Statement::SetSyncConflictPolicy(policy) => {
67 Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
68 }
69 Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
70 Statement::Begin | Statement::Commit | Statement::Rollback => {
71 Ok(PhysicalPlan::Pipeline(vec![]))
72 }
73 }
74}
75
76fn extract_propagation_rules(
77 ct: &contextdb_parser::ast::CreateTable,
78) -> Result<Vec<PropagationRule>> {
79 let mut rules = Vec::new();
80
81 for column in &ct.columns {
82 if let Some(fk) = &column.references {
83 for rule in &fk.propagation_rules {
84 if let AstPropagationRule::FkState {
85 trigger_state,
86 target_state,
87 max_depth,
88 abort_on_failure,
89 } = rule
90 {
91 rules.push(PropagationRule::ForeignKey {
92 fk_column: column.name.clone(),
93 referenced_table: fk.table.clone(),
94 referenced_column: fk.column.clone(),
95 trigger_state: trigger_state.clone(),
96 target_state: target_state.clone(),
97 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
98 abort_on_failure: *abort_on_failure,
99 });
100 }
101 }
102 }
103 }
104
105 for rule in &ct.propagation_rules {
106 match rule {
107 AstPropagationRule::EdgeState {
108 edge_type,
109 direction,
110 trigger_state,
111 target_state,
112 max_depth,
113 abort_on_failure,
114 } => {
115 let direction = match direction.to_ascii_uppercase().as_str() {
116 "OUTGOING" => Direction::Outgoing,
117 "INCOMING" => Direction::Incoming,
118 "BOTH" => Direction::Both,
119 other => {
120 return Err(Error::PlanError(format!(
121 "invalid edge direction in propagation rule: {}",
122 other
123 )));
124 }
125 };
126 rules.push(PropagationRule::Edge {
127 edge_type: edge_type.clone(),
128 direction,
129 trigger_state: trigger_state.clone(),
130 target_state: target_state.clone(),
131 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
132 abort_on_failure: *abort_on_failure,
133 });
134 }
135 AstPropagationRule::VectorExclusion { trigger_state } => {
136 rules.push(PropagationRule::VectorExclusion {
137 trigger_state: trigger_state.clone(),
138 });
139 }
140 AstPropagationRule::FkState { .. } => {}
141 }
142 }
143
144 Ok(rules)
145}
146
147fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
148 let mut cte_env = HashMap::new();
149
150 for cte in &sel.ctes {
151 match cte {
152 Cte::MatchCte { name, match_clause } => {
153 cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
154 }
155 Cte::SqlCte { name, query } => {
156 cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
157 }
158 }
159 }
160
161 plan_select_body(&sel.body, &cte_env)
162}
163
164fn plan_select_body(
165 body: &SelectBody,
166 cte_env: &HashMap<String, PhysicalPlan>,
167) -> Result<PhysicalPlan> {
168 let graph_from = body
169 .from
170 .iter()
171 .find(|f| matches!(f, FromItem::GraphTable { .. }));
172
173 let mut current = if let Some(from_item) = graph_from {
174 graph_plan_from_from_item(from_item, cte_env)?
175 } else {
176 let from_item = body.from.iter().find_map(|item| match item {
177 FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
178 FromItem::GraphTable { .. } => None,
179 });
180
181 match from_item {
182 Some((from_table, from_alias)) => {
183 if let Some(cte_plan) = cte_env.get(&from_table) {
184 let mut cte_plan = cte_plan.clone();
185 if body.joins.is_empty()
186 && let Some(where_clause) = &body.where_clause
187 {
188 cte_plan = PhysicalPlan::Filter {
189 input: Box::new(cte_plan),
190 predicate: where_clause.clone(),
191 };
192 }
193 cte_plan
194 } else {
195 PhysicalPlan::Scan {
196 table: from_table,
197 alias: from_alias.clone(),
198 filter: if body.joins.is_empty() {
199 body.where_clause.clone()
200 } else {
201 None
202 },
203 }
204 }
205 }
206 None => PhysicalPlan::Scan {
207 table: "dual".to_string(),
208 alias: None,
209 filter: None,
210 },
211 }
212 };
213
214 if !body.joins.is_empty() {
215 let left_alias = body.from.iter().find_map(|item| match item {
216 FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
217 FromItem::GraphTable { .. } => None,
218 });
219
220 for join in &body.joins {
221 let right = if let Some(cte_plan) = cte_env.get(&join.table) {
222 cte_plan.clone()
223 } else {
224 PhysicalPlan::Scan {
225 table: join.table.clone(),
226 alias: join.alias.clone(),
227 filter: None,
228 }
229 };
230
231 current = PhysicalPlan::Join {
232 left: Box::new(current),
233 right: Box::new(right),
234 condition: join.on.clone(),
235 join_type: match join.join_type {
236 contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
237 contextdb_parser::ast::JoinType::Left => JoinType::Left,
238 },
239 left_alias: left_alias.clone(),
240 right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
241 };
242 }
243
244 if let Some(where_clause) = &body.where_clause {
245 current = PhysicalPlan::Filter {
246 input: Box::new(current),
247 predicate: where_clause.clone(),
248 };
249 }
250 }
251
252 let uses_vector_search = body
253 .order_by
254 .first()
255 .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
256
257 if let Some(order) = body.order_by.first()
258 && matches!(order.direction, SortDirection::CosineDistance)
259 {
260 let k = body.limit.ok_or(Error::UnboundedVectorSearch)?;
261 let vector_table = vector_base_table(¤t)?.ok_or_else(|| {
262 Error::PlanError("unable to resolve physical vector source table".to_string())
263 })?;
264 current = PhysicalPlan::VectorSearch {
265 table: vector_table,
266 column: "embedding".to_string(),
267 query_expr: order.expr.clone(),
268 k,
269 candidates: Some(Box::new(current)),
270 };
271 }
272
273 if !body.order_by.is_empty() && !uses_vector_search {
274 current = PhysicalPlan::Sort {
275 input: Box::new(current),
276 keys: body
277 .order_by
278 .iter()
279 .map(|item| SortKey {
280 expr: item.expr.clone(),
281 direction: item.direction,
282 })
283 .collect(),
284 };
285 }
286
287 let is_select_star = matches!(
288 body.columns.as_slice(),
289 [contextdb_parser::ast::SelectColumn {
290 expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
291 alias: None
292 }] if column == "*"
293 );
294 if !is_select_star {
295 current = PhysicalPlan::Project {
296 input: Box::new(current),
297 columns: body
298 .columns
299 .iter()
300 .map(|column| ProjectColumn {
301 expr: column.expr.clone(),
302 alias: column.alias.clone(),
303 })
304 .collect(),
305 };
306 }
307
308 if body.distinct {
309 current = PhysicalPlan::Distinct {
310 input: Box::new(current),
311 };
312 }
313
314 if let Some(limit) = body.limit
315 && !uses_vector_search
316 {
317 current = PhysicalPlan::Limit {
318 input: Box::new(current),
319 count: limit,
320 };
321 }
322
323 Ok(current)
324}
325
326fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
327 match plan {
328 PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
329 Ok(Some(table.clone()))
330 }
331 PhysicalPlan::Filter { input, .. }
332 | PhysicalPlan::Project { input, .. }
333 | PhysicalPlan::Distinct { input }
334 | PhysicalPlan::Sort { input, .. }
335 | PhysicalPlan::Limit { input, .. }
336 | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
337 PhysicalPlan::Join { left, right, .. } => {
338 let left_table = vector_base_table(left)?;
339 let right_table = vector_base_table(right)?;
340 match (left_table, right_table) {
341 (Some(left), Some(right)) if left == right => Ok(Some(left)),
342 (Some(_), Some(_)) => Err(Error::PlanError(
343 "ambiguous physical vector source table in join".to_string(),
344 )),
345 (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
346 (None, None) => Ok(None),
347 }
348 }
349 PhysicalPlan::Pipeline(plans) => {
350 for plan in plans.iter().rev() {
351 if let Some(table) = vector_base_table(plan)? {
352 return Ok(Some(table));
353 }
354 }
355 Ok(None)
356 }
357 PhysicalPlan::GraphBfs { .. }
358 | PhysicalPlan::CteRef { .. }
359 | PhysicalPlan::Union { .. } => Ok(None),
360 _ => Ok(None),
361 }
362}
363
364fn graph_plan_from_from_item(
365 from_item: &FromItem,
366 cte_env: &HashMap<String, PhysicalPlan>,
367) -> Result<PhysicalPlan> {
368 match from_item {
369 FromItem::GraphTable {
370 match_clause,
371 columns,
372 ..
373 } => {
374 let bfs = graph_bfs_from_match(match_clause, cte_env)?;
375 if columns.is_empty() {
376 Ok(bfs)
377 } else {
378 Ok(PhysicalPlan::Project {
379 input: Box::new(bfs),
380 columns: columns
381 .iter()
382 .map(|c| ProjectColumn {
383 expr: c.expr.clone(),
384 alias: Some(c.alias.clone()),
385 })
386 .collect(),
387 })
388 }
389 }
390 FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
391 table: name.clone(),
392 alias: None,
393 filter: None,
394 }),
395 }
396}
397
398fn graph_bfs_from_match(
399 match_clause: &contextdb_parser::ast::MatchClause,
400 cte_env: &HashMap<String, PhysicalPlan>,
401) -> Result<PhysicalPlan> {
402 let steps = match_clause
403 .pattern
404 .edges
405 .iter()
406 .map(|step| {
407 let max_depth = if step.max_hops == 0 {
408 DEFAULT_MATCH_DEPTH
409 } else {
410 step.max_hops
411 };
412 if max_depth > ENGINE_MAX_BFS_DEPTH {
413 return Err(Error::BfsDepthExceeded(max_depth));
414 }
415
416 Ok(GraphStepPlan {
417 edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
418 direction: match step.direction {
419 contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
420 contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
421 contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
422 },
423 min_depth: step.min_hops.max(1),
424 max_depth,
425 target_alias: step.target.alias.clone(),
426 })
427 })
428 .collect::<Result<Vec<_>>>()?;
429 if steps.is_empty() {
430 return Err(Error::PlanError(
431 "MATCH must include at least one edge".into(),
432 ));
433 }
434
435 Ok(PhysicalPlan::GraphBfs {
436 start_alias: match_clause.pattern.start.alias.clone(),
437 start_expr: extract_graph_start_expr(match_clause)?,
438 start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
439 steps,
440 filter: match_clause.where_clause.clone(),
441 })
442}
443
444fn extract_graph_start_candidates(
445 match_clause: &MatchClause,
446 cte_env: &HashMap<String, PhysicalPlan>,
447) -> Result<Option<Box<PhysicalPlan>>> {
448 let Some(where_clause) = &match_clause.where_clause else {
449 return Ok(None);
450 };
451 find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
452}
453
454fn find_graph_start_candidates(
455 expr: &Expr,
456 start_alias: &str,
457 cte_env: &HashMap<String, PhysicalPlan>,
458) -> Result<Option<Box<PhysicalPlan>>> {
459 match expr {
460 Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
461 Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
462 }
463 Expr::BinaryOp { left, right, .. } => {
464 if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
465 return Ok(Some(plan));
466 }
467 find_graph_start_candidates(right, start_alias, cte_env)
468 }
469 Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
470 _ => Ok(None),
471 }
472}
473
474fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
475 let start_alias = &match_clause.pattern.start.alias;
476 if let Some(where_clause) = &match_clause.where_clause
477 && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
478 {
479 return Ok(expr);
480 }
481
482 Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
483 table: None,
484 column: start_alias.clone(),
485 }))
486}
487
488fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
489 match expr {
490 Expr::BinaryOp {
491 left,
492 op: BinOp::Eq,
493 right,
494 } => {
495 if is_graph_start_id_ref(left, start_alias) {
496 Some((**right).clone())
497 } else if is_graph_start_id_ref(right, start_alias) {
498 Some((**left).clone())
499 } else {
500 None
501 }
502 }
503 Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
504 .or_else(|| find_graph_start_expr(right, start_alias)),
505 Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
506 _ => None,
507 }
508}
509
510fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
511 matches!(
512 expr,
513 Expr::Column(contextdb_parser::ast::ColumnRef {
514 table: Some(table),
515 column
516 }) if table == start_alias && column == "id"
517 )
518}
519
520fn map_parser_to_core_sort_direction(
524 dir: contextdb_parser::ast::SortDirection,
525) -> Result<contextdb_core::SortDirection> {
526 match dir {
527 contextdb_parser::ast::SortDirection::Asc => Ok(contextdb_core::SortDirection::Asc),
528 contextdb_parser::ast::SortDirection::Desc => Ok(contextdb_core::SortDirection::Desc),
529 contextdb_parser::ast::SortDirection::CosineDistance => Err(Error::ParseError(
530 "CosineDistance is not a valid CREATE INDEX direction".to_string(),
531 )),
532 }
533}
534
// Nothing in the visible part of this file calls this function, hence the
// lint exemption.
// NOTE(review): presumably a placeholder for index-aware planning; confirm.
#[allow(dead_code)]
/// Stub for choosing an index scan for `_table` given an optional predicate
/// and the table's declared indexes.
///
/// All parameters are currently ignored and the function always returns
/// `None`.
fn try_plan_index_scan(
    _table: &str,
    _where_clause: Option<&Expr>,
    _indexes: &[contextdb_core::table_meta::IndexDecl],
) -> Option<crate::plan::PhysicalPlan> {
    None
}