1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4 AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5 SortDirection, Statement,
6};
7use std::collections::HashMap;
8
/// Upper hop bound used for a MATCH step whose `max_hops` is 0.
const DEFAULT_MATCH_DEPTH: u32 = 5;
/// Largest per-step BFS depth the planner accepts; a deeper step is rejected
/// with `Error::BfsDepthExceeded`.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;
/// `max_depth` given to a propagation rule that does not specify one.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14 match stmt {
15 Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16 name: ct.name.clone(),
17 columns: ct.columns.clone(),
18 unique_constraints: ct.unique_constraints.clone(),
19 immutable: ct.immutable,
20 state_machine: ct.state_machine.clone(),
21 dag_edge_types: ct.dag_edge_types.clone(),
22 propagation_rules: extract_propagation_rules(ct)?,
23 retain: ct.retain.clone(),
24 })),
25 Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
26 table: at.table.clone(),
27 action: at.action.clone(),
28 })),
29 Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
30 Statement::CreateIndex(ci) => {
31 let mut columns = Vec::with_capacity(ci.columns.len());
32 for (col, dir) in &ci.columns {
33 columns.push((col.clone(), map_parser_to_core_sort_direction(*dir)?));
34 }
35 Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
36 name: ci.name.clone(),
37 table: ci.table.clone(),
38 columns,
39 }))
40 }
41 Statement::DropIndex(di) => Ok(PhysicalPlan::DropIndex(DropIndexPlan {
42 name: di.name.clone(),
43 table: di.table.clone(),
44 if_exists: di.if_exists,
45 })),
46 Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
47 table: i.table.clone(),
48 columns: i.columns.clone(),
49 values: i.values.clone(),
50 on_conflict: i.on_conflict.clone().map(Into::into),
51 })),
52 Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
53 table: d.table.clone(),
54 where_clause: d.where_clause.clone(),
55 })),
56 Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
57 table: u.table.clone(),
58 assignments: u.assignments.clone(),
59 where_clause: u.where_clause.clone(),
60 })),
61 Statement::Select(sel) => plan_select(sel),
62 Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
63 Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
64 Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
65 Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
66 Statement::SetSyncConflictPolicy(policy) => {
67 Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
68 }
69 Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
70 Statement::ShowVectorIndexes => Ok(PhysicalPlan::ShowVectorIndexes),
71 Statement::Begin | Statement::Commit | Statement::Rollback => {
72 Ok(PhysicalPlan::Pipeline(vec![]))
73 }
74 }
75}
76
77fn extract_propagation_rules(
78 ct: &contextdb_parser::ast::CreateTable,
79) -> Result<Vec<PropagationRule>> {
80 let mut rules = Vec::new();
81
82 for column in &ct.columns {
83 if let Some(fk) = &column.references {
84 for rule in &fk.propagation_rules {
85 if let AstPropagationRule::FkState {
86 trigger_state,
87 target_state,
88 max_depth,
89 abort_on_failure,
90 } = rule
91 {
92 rules.push(PropagationRule::ForeignKey {
93 fk_column: column.name.clone(),
94 referenced_table: fk.table.clone(),
95 referenced_column: fk.column.clone(),
96 trigger_state: trigger_state.clone(),
97 target_state: target_state.clone(),
98 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
99 abort_on_failure: *abort_on_failure,
100 });
101 }
102 }
103 }
104 }
105
106 for rule in &ct.propagation_rules {
107 match rule {
108 AstPropagationRule::EdgeState {
109 edge_type,
110 direction,
111 trigger_state,
112 target_state,
113 max_depth,
114 abort_on_failure,
115 } => {
116 let direction = match direction.to_ascii_uppercase().as_str() {
117 "OUTGOING" => Direction::Outgoing,
118 "INCOMING" => Direction::Incoming,
119 "BOTH" => Direction::Both,
120 other => {
121 return Err(Error::PlanError(format!(
122 "invalid edge direction in propagation rule: {}",
123 other
124 )));
125 }
126 };
127 rules.push(PropagationRule::Edge {
128 edge_type: edge_type.clone(),
129 direction,
130 trigger_state: trigger_state.clone(),
131 target_state: target_state.clone(),
132 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
133 abort_on_failure: *abort_on_failure,
134 });
135 }
136 AstPropagationRule::VectorExclusion { trigger_state } => {
137 rules.push(PropagationRule::VectorExclusion {
138 trigger_state: trigger_state.clone(),
139 });
140 }
141 AstPropagationRule::FkState { .. } => {}
142 }
143 }
144
145 Ok(rules)
146}
147
148fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
149 let mut cte_env = HashMap::new();
150
151 for cte in &sel.ctes {
152 match cte {
153 Cte::MatchCte { name, match_clause } => {
154 cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
155 }
156 Cte::SqlCte { name, query } => {
157 cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
158 }
159 }
160 }
161
162 plan_select_body(&sel.body, &cte_env)
163}
164
165fn plan_select_body(
166 body: &SelectBody,
167 cte_env: &HashMap<String, PhysicalPlan>,
168) -> Result<PhysicalPlan> {
169 let graph_from = body
170 .from
171 .iter()
172 .find(|f| matches!(f, FromItem::GraphTable { .. }));
173
174 let mut current = if let Some(from_item) = graph_from {
175 graph_plan_from_from_item(from_item, cte_env)?
176 } else {
177 let from_item = body.from.iter().find_map(|item| match item {
178 FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
179 FromItem::GraphTable { .. } => None,
180 });
181
182 match from_item {
183 Some((from_table, from_alias)) => {
184 if let Some(cte_plan) = cte_env.get(&from_table) {
185 let mut cte_plan = cte_plan.clone();
186 if body.joins.is_empty()
187 && let Some(where_clause) = &body.where_clause
188 {
189 cte_plan = PhysicalPlan::Filter {
190 input: Box::new(cte_plan),
191 predicate: where_clause.clone(),
192 };
193 }
194 cte_plan
195 } else {
196 PhysicalPlan::Scan {
197 table: from_table,
198 alias: from_alias.clone(),
199 filter: if body.joins.is_empty() {
200 body.where_clause.clone()
201 } else {
202 None
203 },
204 }
205 }
206 }
207 None => PhysicalPlan::Scan {
208 table: "dual".to_string(),
209 alias: None,
210 filter: None,
211 },
212 }
213 };
214
215 if !body.joins.is_empty() {
216 let left_alias = body.from.iter().find_map(|item| match item {
217 FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
218 FromItem::GraphTable { .. } => None,
219 });
220
221 for join in &body.joins {
222 let right = if let Some(cte_plan) = cte_env.get(&join.table) {
223 cte_plan.clone()
224 } else {
225 PhysicalPlan::Scan {
226 table: join.table.clone(),
227 alias: join.alias.clone(),
228 filter: None,
229 }
230 };
231
232 current = PhysicalPlan::Join {
233 left: Box::new(current),
234 right: Box::new(right),
235 condition: join.on.clone(),
236 join_type: match join.join_type {
237 contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
238 contextdb_parser::ast::JoinType::Left => JoinType::Left,
239 },
240 left_alias: left_alias.clone(),
241 right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
242 };
243 }
244
245 if let Some(where_clause) = &body.where_clause {
246 current = PhysicalPlan::Filter {
247 input: Box::new(current),
248 predicate: where_clause.clone(),
249 };
250 }
251 }
252
253 let uses_vector_search = body
254 .order_by
255 .first()
256 .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
257
258 if body.use_rank.is_some() && !uses_vector_search {
259 return Err(Error::UseRankRequiresVectorOrder);
260 }
261
262 if let Some(order) = body.order_by.first()
263 && matches!(order.direction, SortDirection::CosineDistance)
264 {
265 let k = body.limit.ok_or_else(|| {
266 if body.use_rank.is_some() {
267 Error::UseRankRequiresLimit
268 } else {
269 Error::UnboundedVectorSearch
270 }
271 })?;
272 let vector_table = vector_base_table(¤t)?.ok_or_else(|| {
273 Error::PlanError("unable to resolve physical vector source table".to_string())
274 })?;
275 let (column, query_expr) = vector_search_parts(&order.expr)?;
276 current = PhysicalPlan::VectorSearch {
277 table: vector_table,
278 column,
279 query_expr,
280 k,
281 candidates: Some(Box::new(current)),
282 sort_key: body.use_rank.clone(),
283 };
284 }
285
286 if !body.order_by.is_empty() && !uses_vector_search {
287 current = PhysicalPlan::Sort {
288 input: Box::new(current),
289 keys: body
290 .order_by
291 .iter()
292 .map(|item| SortKey {
293 expr: item.expr.clone(),
294 direction: item.direction,
295 })
296 .collect(),
297 };
298 }
299
300 let is_select_star = matches!(
301 body.columns.as_slice(),
302 [contextdb_parser::ast::SelectColumn {
303 expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
304 alias: None
305 }] if column == "*"
306 );
307 if !is_select_star {
308 current = PhysicalPlan::Project {
309 input: Box::new(current),
310 columns: body
311 .columns
312 .iter()
313 .map(|column| ProjectColumn {
314 expr: column.expr.clone(),
315 alias: column.alias.clone(),
316 })
317 .collect(),
318 };
319 }
320
321 if body.distinct {
322 current = PhysicalPlan::Distinct {
323 input: Box::new(current),
324 };
325 }
326
327 if let Some(limit) = body.limit
328 && !uses_vector_search
329 {
330 current = PhysicalPlan::Limit {
331 input: Box::new(current),
332 count: limit,
333 };
334 }
335
336 Ok(current)
337}
338
339fn vector_search_parts(expr: &Expr) -> Result<(String, Expr)> {
340 match expr {
341 Expr::CosineDistance { left, right } => match left.as_ref() {
342 Expr::Column(column) => Ok((column.column.clone(), right.as_ref().clone())),
343 _ => Err(Error::PlanError(
344 "left side of vector distance must be a column".to_string(),
345 )),
346 },
347 _ => Err(Error::PlanError(
348 "vector search requires a cosine distance expression".to_string(),
349 )),
350 }
351}
352
353fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
354 match plan {
355 PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
356 Ok(Some(table.clone()))
357 }
358 PhysicalPlan::Filter { input, .. }
359 | PhysicalPlan::Project { input, .. }
360 | PhysicalPlan::Distinct { input }
361 | PhysicalPlan::Sort { input, .. }
362 | PhysicalPlan::Limit { input, .. }
363 | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
364 PhysicalPlan::Join { left, right, .. } => {
365 let left_table = vector_base_table(left)?;
366 let right_table = vector_base_table(right)?;
367 match (left_table, right_table) {
368 (Some(left), Some(right)) if left == right => Ok(Some(left)),
369 (Some(_), Some(_)) => Err(Error::PlanError(
370 "ambiguous physical vector source table in join".to_string(),
371 )),
372 (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
373 (None, None) => Ok(None),
374 }
375 }
376 PhysicalPlan::Pipeline(plans) => {
377 for plan in plans.iter().rev() {
378 if let Some(table) = vector_base_table(plan)? {
379 return Ok(Some(table));
380 }
381 }
382 Ok(None)
383 }
384 PhysicalPlan::GraphBfs { .. }
385 | PhysicalPlan::CteRef { .. }
386 | PhysicalPlan::Union { .. } => Ok(None),
387 _ => Ok(None),
388 }
389}
390
391fn graph_plan_from_from_item(
392 from_item: &FromItem,
393 cte_env: &HashMap<String, PhysicalPlan>,
394) -> Result<PhysicalPlan> {
395 match from_item {
396 FromItem::GraphTable {
397 match_clause,
398 columns,
399 ..
400 } => {
401 let bfs = graph_bfs_from_match(match_clause, cte_env)?;
402 if columns.is_empty() {
403 Ok(bfs)
404 } else {
405 Ok(PhysicalPlan::Project {
406 input: Box::new(bfs),
407 columns: columns
408 .iter()
409 .map(|c| ProjectColumn {
410 expr: c.expr.clone(),
411 alias: Some(c.alias.clone()),
412 })
413 .collect(),
414 })
415 }
416 }
417 FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
418 table: name.clone(),
419 alias: None,
420 filter: None,
421 }),
422 }
423}
424
425fn graph_bfs_from_match(
426 match_clause: &contextdb_parser::ast::MatchClause,
427 cte_env: &HashMap<String, PhysicalPlan>,
428) -> Result<PhysicalPlan> {
429 let steps = match_clause
430 .pattern
431 .edges
432 .iter()
433 .map(|step| {
434 let max_depth = if step.max_hops == 0 {
435 DEFAULT_MATCH_DEPTH
436 } else {
437 step.max_hops
438 };
439 if max_depth > ENGINE_MAX_BFS_DEPTH {
440 return Err(Error::BfsDepthExceeded(max_depth));
441 }
442
443 Ok(GraphStepPlan {
444 edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
445 direction: match step.direction {
446 contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
447 contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
448 contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
449 },
450 min_depth: step.min_hops.max(1),
451 max_depth,
452 target_alias: step.target.alias.clone(),
453 })
454 })
455 .collect::<Result<Vec<_>>>()?;
456 if steps.is_empty() {
457 return Err(Error::PlanError(
458 "MATCH must include at least one edge".into(),
459 ));
460 }
461
462 Ok(PhysicalPlan::GraphBfs {
463 start_alias: match_clause.pattern.start.alias.clone(),
464 start_expr: extract_graph_start_expr(match_clause)?,
465 start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
466 steps,
467 filter: match_clause.where_clause.clone(),
468 })
469}
470
471fn extract_graph_start_candidates(
472 match_clause: &MatchClause,
473 cte_env: &HashMap<String, PhysicalPlan>,
474) -> Result<Option<Box<PhysicalPlan>>> {
475 let Some(where_clause) = &match_clause.where_clause else {
476 return Ok(None);
477 };
478 find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
479}
480
481fn find_graph_start_candidates(
482 expr: &Expr,
483 start_alias: &str,
484 cte_env: &HashMap<String, PhysicalPlan>,
485) -> Result<Option<Box<PhysicalPlan>>> {
486 match expr {
487 Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
488 Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
489 }
490 Expr::BinaryOp { left, right, .. } => {
491 if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
492 return Ok(Some(plan));
493 }
494 find_graph_start_candidates(right, start_alias, cte_env)
495 }
496 Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
497 _ => Ok(None),
498 }
499}
500
501fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
502 let start_alias = &match_clause.pattern.start.alias;
503 if let Some(where_clause) = &match_clause.where_clause
504 && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
505 {
506 return Ok(expr);
507 }
508
509 Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
510 table: None,
511 column: start_alias.clone(),
512 }))
513}
514
515fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
516 match expr {
517 Expr::BinaryOp {
518 left,
519 op: BinOp::Eq,
520 right,
521 } => {
522 if is_graph_start_id_ref(left, start_alias) {
523 Some((**right).clone())
524 } else if is_graph_start_id_ref(right, start_alias) {
525 Some((**left).clone())
526 } else {
527 None
528 }
529 }
530 Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
531 .or_else(|| find_graph_start_expr(right, start_alias)),
532 Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
533 _ => None,
534 }
535}
536
537fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
538 matches!(
539 expr,
540 Expr::Column(contextdb_parser::ast::ColumnRef {
541 table: Some(table),
542 column
543 }) if table == start_alias && column == "id"
544 )
545}
546
547fn map_parser_to_core_sort_direction(
551 dir: contextdb_parser::ast::SortDirection,
552) -> Result<contextdb_core::SortDirection> {
553 match dir {
554 contextdb_parser::ast::SortDirection::Asc => Ok(contextdb_core::SortDirection::Asc),
555 contextdb_parser::ast::SortDirection::Desc => Ok(contextdb_core::SortDirection::Desc),
556 contextdb_parser::ast::SortDirection::CosineDistance => Err(Error::ParseError(
557 "CosineDistance is not a valid CREATE INDEX direction".to_string(),
558 )),
559 }
560}
561
// `dead_code` is allowed because nothing in the visible part of this file
// calls this function.
#[allow(dead_code)]
/// Stub for index-aware planning: currently ignores all of its arguments and
/// always returns `None`, so no index scan is ever produced here.
///
/// NOTE(review): presumably a placeholder for choosing a
/// `PhysicalPlan::IndexScan` from `_where_clause` and `_indexes` — confirm
/// the intended behaviour before relying on it.
fn try_plan_index_scan(
    _table: &str,
    _where_clause: Option<&Expr>,
    _indexes: &[contextdb_core::table_meta::IndexDecl],
) -> Option<crate::plan::PhysicalPlan> {
    None
}