1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4 AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5 SortDirection, Statement,
6};
7use std::collections::HashMap;
8
/// Maximum hop count used for a MATCH edge step whose parsed `max_hops` is 0
/// (presumably: no explicit upper bound was written — confirm with the parser).
const DEFAULT_MATCH_DEPTH: u32 = 5;
/// Hard cap on the max depth of any single MATCH edge step; a step asking for
/// more is rejected with `Error::BfsDepthExceeded`.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;
/// Depth used for a propagation rule that does not declare its own `max_depth`.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14 match stmt {
15 Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16 name: ct.name.clone(),
17 columns: ct.columns.clone(),
18 unique_constraints: ct.unique_constraints.clone(),
19 immutable: ct.immutable,
20 state_machine: ct.state_machine.clone(),
21 dag_edge_types: ct.dag_edge_types.clone(),
22 propagation_rules: extract_propagation_rules(ct)?,
23 retain: ct.retain.clone(),
24 })),
25 Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
26 table: at.table.clone(),
27 action: at.action.clone(),
28 })),
29 Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
30 Statement::CreateIndex(ci) => Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
31 name: ci.name.clone(),
32 table: ci.table.clone(),
33 columns: ci.columns.clone(),
34 })),
35 Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
36 table: i.table.clone(),
37 columns: i.columns.clone(),
38 values: i.values.clone(),
39 on_conflict: i.on_conflict.clone().map(Into::into),
40 })),
41 Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
42 table: d.table.clone(),
43 where_clause: d.where_clause.clone(),
44 })),
45 Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
46 table: u.table.clone(),
47 assignments: u.assignments.clone(),
48 where_clause: u.where_clause.clone(),
49 })),
50 Statement::Select(sel) => plan_select(sel),
51 Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
52 Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
53 Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
54 Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
55 Statement::SetSyncConflictPolicy(policy) => {
56 Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
57 }
58 Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
59 Statement::Begin | Statement::Commit | Statement::Rollback => {
60 Ok(PhysicalPlan::Pipeline(vec![]))
61 }
62 }
63}
64
65fn extract_propagation_rules(
66 ct: &contextdb_parser::ast::CreateTable,
67) -> Result<Vec<PropagationRule>> {
68 let mut rules = Vec::new();
69
70 for column in &ct.columns {
71 if let Some(fk) = &column.references {
72 for rule in &fk.propagation_rules {
73 if let AstPropagationRule::FkState {
74 trigger_state,
75 target_state,
76 max_depth,
77 abort_on_failure,
78 } = rule
79 {
80 rules.push(PropagationRule::ForeignKey {
81 fk_column: column.name.clone(),
82 referenced_table: fk.table.clone(),
83 referenced_column: fk.column.clone(),
84 trigger_state: trigger_state.clone(),
85 target_state: target_state.clone(),
86 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
87 abort_on_failure: *abort_on_failure,
88 });
89 }
90 }
91 }
92 }
93
94 for rule in &ct.propagation_rules {
95 match rule {
96 AstPropagationRule::EdgeState {
97 edge_type,
98 direction,
99 trigger_state,
100 target_state,
101 max_depth,
102 abort_on_failure,
103 } => {
104 let direction = match direction.to_ascii_uppercase().as_str() {
105 "OUTGOING" => Direction::Outgoing,
106 "INCOMING" => Direction::Incoming,
107 "BOTH" => Direction::Both,
108 other => {
109 return Err(Error::PlanError(format!(
110 "invalid edge direction in propagation rule: {}",
111 other
112 )));
113 }
114 };
115 rules.push(PropagationRule::Edge {
116 edge_type: edge_type.clone(),
117 direction,
118 trigger_state: trigger_state.clone(),
119 target_state: target_state.clone(),
120 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
121 abort_on_failure: *abort_on_failure,
122 });
123 }
124 AstPropagationRule::VectorExclusion { trigger_state } => {
125 rules.push(PropagationRule::VectorExclusion {
126 trigger_state: trigger_state.clone(),
127 });
128 }
129 AstPropagationRule::FkState { .. } => {}
130 }
131 }
132
133 Ok(rules)
134}
135
136fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
137 let mut cte_env = HashMap::new();
138
139 for cte in &sel.ctes {
140 match cte {
141 Cte::MatchCte { name, match_clause } => {
142 cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
143 }
144 Cte::SqlCte { name, query } => {
145 cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
146 }
147 }
148 }
149
150 plan_select_body(&sel.body, &cte_env)
151}
152
153fn plan_select_body(
154 body: &SelectBody,
155 cte_env: &HashMap<String, PhysicalPlan>,
156) -> Result<PhysicalPlan> {
157 let graph_from = body
158 .from
159 .iter()
160 .find(|f| matches!(f, FromItem::GraphTable { .. }));
161
162 let mut current = if let Some(from_item) = graph_from {
163 graph_plan_from_from_item(from_item, cte_env)?
164 } else {
165 let from_item = body.from.iter().find_map(|item| match item {
166 FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
167 FromItem::GraphTable { .. } => None,
168 });
169
170 match from_item {
171 Some((from_table, from_alias)) => {
172 if let Some(cte_plan) = cte_env.get(&from_table) {
173 let mut cte_plan = cte_plan.clone();
174 if body.joins.is_empty()
175 && let Some(where_clause) = &body.where_clause
176 {
177 cte_plan = PhysicalPlan::Filter {
178 input: Box::new(cte_plan),
179 predicate: where_clause.clone(),
180 };
181 }
182 cte_plan
183 } else {
184 PhysicalPlan::Scan {
185 table: from_table,
186 alias: from_alias.clone(),
187 filter: if body.joins.is_empty() {
188 body.where_clause.clone()
189 } else {
190 None
191 },
192 }
193 }
194 }
195 None => PhysicalPlan::Scan {
196 table: "dual".to_string(),
197 alias: None,
198 filter: None,
199 },
200 }
201 };
202
203 if !body.joins.is_empty() {
204 let left_alias = body.from.iter().find_map(|item| match item {
205 FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
206 FromItem::GraphTable { .. } => None,
207 });
208
209 for join in &body.joins {
210 let right = if let Some(cte_plan) = cte_env.get(&join.table) {
211 cte_plan.clone()
212 } else {
213 PhysicalPlan::Scan {
214 table: join.table.clone(),
215 alias: join.alias.clone(),
216 filter: None,
217 }
218 };
219
220 current = PhysicalPlan::Join {
221 left: Box::new(current),
222 right: Box::new(right),
223 condition: join.on.clone(),
224 join_type: match join.join_type {
225 contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
226 contextdb_parser::ast::JoinType::Left => JoinType::Left,
227 },
228 left_alias: left_alias.clone(),
229 right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
230 };
231 }
232
233 if let Some(where_clause) = &body.where_clause {
234 current = PhysicalPlan::Filter {
235 input: Box::new(current),
236 predicate: where_clause.clone(),
237 };
238 }
239 }
240
241 let uses_vector_search = body
242 .order_by
243 .first()
244 .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
245
246 if let Some(order) = body.order_by.first()
247 && matches!(order.direction, SortDirection::CosineDistance)
248 {
249 let k = body.limit.ok_or(Error::UnboundedVectorSearch)?;
250 let vector_table = vector_base_table(¤t)?.ok_or_else(|| {
251 Error::PlanError("unable to resolve physical vector source table".to_string())
252 })?;
253 current = PhysicalPlan::VectorSearch {
254 table: vector_table,
255 column: "embedding".to_string(),
256 query_expr: order.expr.clone(),
257 k,
258 candidates: Some(Box::new(current)),
259 };
260 }
261
262 if !body.order_by.is_empty() && !uses_vector_search {
263 current = PhysicalPlan::Sort {
264 input: Box::new(current),
265 keys: body
266 .order_by
267 .iter()
268 .map(|item| SortKey {
269 expr: item.expr.clone(),
270 direction: item.direction,
271 })
272 .collect(),
273 };
274 }
275
276 let is_select_star = matches!(
277 body.columns.as_slice(),
278 [contextdb_parser::ast::SelectColumn {
279 expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
280 alias: None
281 }] if column == "*"
282 );
283 if !is_select_star {
284 current = PhysicalPlan::Project {
285 input: Box::new(current),
286 columns: body
287 .columns
288 .iter()
289 .map(|column| ProjectColumn {
290 expr: column.expr.clone(),
291 alias: column.alias.clone(),
292 })
293 .collect(),
294 };
295 }
296
297 if body.distinct {
298 current = PhysicalPlan::Distinct {
299 input: Box::new(current),
300 };
301 }
302
303 if let Some(limit) = body.limit
304 && !uses_vector_search
305 {
306 current = PhysicalPlan::Limit {
307 input: Box::new(current),
308 count: limit,
309 };
310 }
311
312 Ok(current)
313}
314
315fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
316 match plan {
317 PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
318 Ok(Some(table.clone()))
319 }
320 PhysicalPlan::Filter { input, .. }
321 | PhysicalPlan::Project { input, .. }
322 | PhysicalPlan::Distinct { input }
323 | PhysicalPlan::Sort { input, .. }
324 | PhysicalPlan::Limit { input, .. }
325 | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
326 PhysicalPlan::Join { left, right, .. } => {
327 let left_table = vector_base_table(left)?;
328 let right_table = vector_base_table(right)?;
329 match (left_table, right_table) {
330 (Some(left), Some(right)) if left == right => Ok(Some(left)),
331 (Some(_), Some(_)) => Err(Error::PlanError(
332 "ambiguous physical vector source table in join".to_string(),
333 )),
334 (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
335 (None, None) => Ok(None),
336 }
337 }
338 PhysicalPlan::Pipeline(plans) => {
339 for plan in plans.iter().rev() {
340 if let Some(table) = vector_base_table(plan)? {
341 return Ok(Some(table));
342 }
343 }
344 Ok(None)
345 }
346 PhysicalPlan::GraphBfs { .. }
347 | PhysicalPlan::CteRef { .. }
348 | PhysicalPlan::Union { .. } => Ok(None),
349 _ => Ok(None),
350 }
351}
352
353fn graph_plan_from_from_item(
354 from_item: &FromItem,
355 cte_env: &HashMap<String, PhysicalPlan>,
356) -> Result<PhysicalPlan> {
357 match from_item {
358 FromItem::GraphTable {
359 match_clause,
360 columns,
361 ..
362 } => {
363 let bfs = graph_bfs_from_match(match_clause, cte_env)?;
364 if columns.is_empty() {
365 Ok(bfs)
366 } else {
367 Ok(PhysicalPlan::Project {
368 input: Box::new(bfs),
369 columns: columns
370 .iter()
371 .map(|c| ProjectColumn {
372 expr: c.expr.clone(),
373 alias: Some(c.alias.clone()),
374 })
375 .collect(),
376 })
377 }
378 }
379 FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
380 table: name.clone(),
381 alias: None,
382 filter: None,
383 }),
384 }
385}
386
387fn graph_bfs_from_match(
388 match_clause: &contextdb_parser::ast::MatchClause,
389 cte_env: &HashMap<String, PhysicalPlan>,
390) -> Result<PhysicalPlan> {
391 let steps = match_clause
392 .pattern
393 .edges
394 .iter()
395 .map(|step| {
396 let max_depth = if step.max_hops == 0 {
397 DEFAULT_MATCH_DEPTH
398 } else {
399 step.max_hops
400 };
401 if max_depth > ENGINE_MAX_BFS_DEPTH {
402 return Err(Error::BfsDepthExceeded(max_depth));
403 }
404
405 Ok(GraphStepPlan {
406 edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
407 direction: match step.direction {
408 contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
409 contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
410 contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
411 },
412 min_depth: step.min_hops.max(1),
413 max_depth,
414 target_alias: step.target.alias.clone(),
415 })
416 })
417 .collect::<Result<Vec<_>>>()?;
418 if steps.is_empty() {
419 return Err(Error::PlanError(
420 "MATCH must include at least one edge".into(),
421 ));
422 }
423
424 Ok(PhysicalPlan::GraphBfs {
425 start_alias: match_clause.pattern.start.alias.clone(),
426 start_expr: extract_graph_start_expr(match_clause)?,
427 start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
428 steps,
429 filter: match_clause.where_clause.clone(),
430 })
431}
432
433fn extract_graph_start_candidates(
434 match_clause: &MatchClause,
435 cte_env: &HashMap<String, PhysicalPlan>,
436) -> Result<Option<Box<PhysicalPlan>>> {
437 let Some(where_clause) = &match_clause.where_clause else {
438 return Ok(None);
439 };
440 find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
441}
442
443fn find_graph_start_candidates(
444 expr: &Expr,
445 start_alias: &str,
446 cte_env: &HashMap<String, PhysicalPlan>,
447) -> Result<Option<Box<PhysicalPlan>>> {
448 match expr {
449 Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
450 Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
451 }
452 Expr::BinaryOp { left, right, .. } => {
453 if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
454 return Ok(Some(plan));
455 }
456 find_graph_start_candidates(right, start_alias, cte_env)
457 }
458 Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
459 _ => Ok(None),
460 }
461}
462
463fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
464 let start_alias = &match_clause.pattern.start.alias;
465 if let Some(where_clause) = &match_clause.where_clause
466 && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
467 {
468 return Ok(expr);
469 }
470
471 Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
472 table: None,
473 column: start_alias.clone(),
474 }))
475}
476
477fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
478 match expr {
479 Expr::BinaryOp {
480 left,
481 op: BinOp::Eq,
482 right,
483 } => {
484 if is_graph_start_id_ref(left, start_alias) {
485 Some((**right).clone())
486 } else if is_graph_start_id_ref(right, start_alias) {
487 Some((**left).clone())
488 } else {
489 None
490 }
491 }
492 Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
493 .or_else(|| find_graph_start_expr(right, start_alias)),
494 Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
495 _ => None,
496 }
497}
498
499fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
500 matches!(
501 expr,
502 Expr::Column(contextdb_parser::ast::ColumnRef {
503 table: Some(table),
504 column
505 }) if table == start_alias && column == "id"
506 )
507}