1use crate::plan::*;
2use contextdb_core::{Direction, Error, PropagationRule, Result};
3use contextdb_parser::ast::{
4 AstPropagationRule, BinOp, Cte, Expr, FromItem, MatchClause, SelectBody, SelectStatement,
5 SortDirection, Statement,
6};
7use std::collections::HashMap;
8
/// Maximum depth applied to a `MATCH` step whose `max_hops` is 0.
const DEFAULT_MATCH_DEPTH: u32 = 5;
/// Largest per-step depth the planner accepts; deeper steps are rejected with
/// `Error::BfsDepthExceeded`.
const ENGINE_MAX_BFS_DEPTH: u32 = 10;
/// Depth used for a propagation rule that does not specify `max_depth`.
const DEFAULT_PROPAGATION_MAX_DEPTH: u32 = 10;
12
13pub fn plan(stmt: &Statement) -> Result<PhysicalPlan> {
14 match stmt {
15 Statement::CreateTable(ct) => Ok(PhysicalPlan::CreateTable(CreateTablePlan {
16 name: ct.name.clone(),
17 columns: ct.columns.clone(),
18 immutable: ct.immutable,
19 state_machine: ct.state_machine.clone(),
20 dag_edge_types: ct.dag_edge_types.clone(),
21 propagation_rules: extract_propagation_rules(ct)?,
22 retain: ct.retain.clone(),
23 })),
24 Statement::AlterTable(at) => Ok(PhysicalPlan::AlterTable(AlterTablePlan {
25 table: at.table.clone(),
26 action: at.action.clone(),
27 })),
28 Statement::DropTable(dt) => Ok(PhysicalPlan::DropTable(dt.name.clone())),
29 Statement::CreateIndex(ci) => Ok(PhysicalPlan::CreateIndex(CreateIndexPlan {
30 name: ci.name.clone(),
31 table: ci.table.clone(),
32 columns: ci.columns.clone(),
33 })),
34 Statement::Insert(i) => Ok(PhysicalPlan::Insert(InsertPlan {
35 table: i.table.clone(),
36 columns: i.columns.clone(),
37 values: i.values.clone(),
38 on_conflict: i.on_conflict.clone().map(Into::into),
39 })),
40 Statement::Delete(d) => Ok(PhysicalPlan::Delete(DeletePlan {
41 table: d.table.clone(),
42 where_clause: d.where_clause.clone(),
43 })),
44 Statement::Update(u) => Ok(PhysicalPlan::Update(UpdatePlan {
45 table: u.table.clone(),
46 assignments: u.assignments.clone(),
47 where_clause: u.where_clause.clone(),
48 })),
49 Statement::Select(sel) => plan_select(sel),
50 Statement::SetMemoryLimit(val) => Ok(PhysicalPlan::SetMemoryLimit(val.clone())),
51 Statement::ShowMemoryLimit => Ok(PhysicalPlan::ShowMemoryLimit),
52 Statement::SetDiskLimit(val) => Ok(PhysicalPlan::SetDiskLimit(val.clone())),
53 Statement::ShowDiskLimit => Ok(PhysicalPlan::ShowDiskLimit),
54 Statement::SetSyncConflictPolicy(policy) => {
55 Ok(PhysicalPlan::SetSyncConflictPolicy(policy.clone()))
56 }
57 Statement::ShowSyncConflictPolicy => Ok(PhysicalPlan::ShowSyncConflictPolicy),
58 Statement::Begin | Statement::Commit | Statement::Rollback => {
59 Ok(PhysicalPlan::Pipeline(vec![]))
60 }
61 }
62}
63
64fn extract_propagation_rules(
65 ct: &contextdb_parser::ast::CreateTable,
66) -> Result<Vec<PropagationRule>> {
67 let mut rules = Vec::new();
68
69 for column in &ct.columns {
70 if let Some(fk) = &column.references {
71 for rule in &fk.propagation_rules {
72 if let AstPropagationRule::FkState {
73 trigger_state,
74 target_state,
75 max_depth,
76 abort_on_failure,
77 } = rule
78 {
79 rules.push(PropagationRule::ForeignKey {
80 fk_column: column.name.clone(),
81 referenced_table: fk.table.clone(),
82 referenced_column: fk.column.clone(),
83 trigger_state: trigger_state.clone(),
84 target_state: target_state.clone(),
85 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
86 abort_on_failure: *abort_on_failure,
87 });
88 }
89 }
90 }
91 }
92
93 for rule in &ct.propagation_rules {
94 match rule {
95 AstPropagationRule::EdgeState {
96 edge_type,
97 direction,
98 trigger_state,
99 target_state,
100 max_depth,
101 abort_on_failure,
102 } => {
103 let direction = match direction.to_ascii_uppercase().as_str() {
104 "OUTGOING" => Direction::Outgoing,
105 "INCOMING" => Direction::Incoming,
106 "BOTH" => Direction::Both,
107 other => {
108 return Err(Error::PlanError(format!(
109 "invalid edge direction in propagation rule: {}",
110 other
111 )));
112 }
113 };
114 rules.push(PropagationRule::Edge {
115 edge_type: edge_type.clone(),
116 direction,
117 trigger_state: trigger_state.clone(),
118 target_state: target_state.clone(),
119 max_depth: max_depth.unwrap_or(DEFAULT_PROPAGATION_MAX_DEPTH),
120 abort_on_failure: *abort_on_failure,
121 });
122 }
123 AstPropagationRule::VectorExclusion { trigger_state } => {
124 rules.push(PropagationRule::VectorExclusion {
125 trigger_state: trigger_state.clone(),
126 });
127 }
128 AstPropagationRule::FkState { .. } => {}
129 }
130 }
131
132 Ok(rules)
133}
134
135fn plan_select(sel: &SelectStatement) -> Result<PhysicalPlan> {
136 let mut cte_env = HashMap::new();
137
138 for cte in &sel.ctes {
139 match cte {
140 Cte::MatchCte { name, match_clause } => {
141 cte_env.insert(name.clone(), graph_bfs_from_match(match_clause, &cte_env)?);
142 }
143 Cte::SqlCte { name, query } => {
144 cte_env.insert(name.clone(), plan_select_body(query, &cte_env)?);
145 }
146 }
147 }
148
149 plan_select_body(&sel.body, &cte_env)
150}
151
152fn plan_select_body(
153 body: &SelectBody,
154 cte_env: &HashMap<String, PhysicalPlan>,
155) -> Result<PhysicalPlan> {
156 let graph_from = body
157 .from
158 .iter()
159 .find(|f| matches!(f, FromItem::GraphTable { .. }));
160
161 let mut current = if let Some(from_item) = graph_from {
162 graph_plan_from_from_item(from_item, cte_env)?
163 } else {
164 let from_item = body.from.iter().find_map(|item| match item {
165 FromItem::Table { name, alias } => Some((name.clone(), alias.clone())),
166 FromItem::GraphTable { .. } => None,
167 });
168
169 match from_item {
170 Some((from_table, from_alias)) => {
171 if let Some(cte_plan) = cte_env.get(&from_table) {
172 let mut cte_plan = cte_plan.clone();
173 if body.joins.is_empty()
174 && let Some(where_clause) = &body.where_clause
175 {
176 cte_plan = PhysicalPlan::Filter {
177 input: Box::new(cte_plan),
178 predicate: where_clause.clone(),
179 };
180 }
181 cte_plan
182 } else {
183 PhysicalPlan::Scan {
184 table: from_table,
185 alias: from_alias.clone(),
186 filter: if body.joins.is_empty() {
187 body.where_clause.clone()
188 } else {
189 None
190 },
191 }
192 }
193 }
194 None => PhysicalPlan::Scan {
195 table: "dual".to_string(),
196 alias: None,
197 filter: None,
198 },
199 }
200 };
201
202 if !body.joins.is_empty() {
203 let left_alias = body.from.iter().find_map(|item| match item {
204 FromItem::Table { alias, name } => alias.clone().or_else(|| Some(name.clone())),
205 FromItem::GraphTable { .. } => None,
206 });
207
208 for join in &body.joins {
209 let right = if let Some(cte_plan) = cte_env.get(&join.table) {
210 cte_plan.clone()
211 } else {
212 PhysicalPlan::Scan {
213 table: join.table.clone(),
214 alias: join.alias.clone(),
215 filter: None,
216 }
217 };
218
219 current = PhysicalPlan::Join {
220 left: Box::new(current),
221 right: Box::new(right),
222 condition: join.on.clone(),
223 join_type: match join.join_type {
224 contextdb_parser::ast::JoinType::Inner => JoinType::Inner,
225 contextdb_parser::ast::JoinType::Left => JoinType::Left,
226 },
227 left_alias: left_alias.clone(),
228 right_alias: join.alias.clone().or_else(|| Some(join.table.clone())),
229 };
230 }
231
232 if let Some(where_clause) = &body.where_clause {
233 current = PhysicalPlan::Filter {
234 input: Box::new(current),
235 predicate: where_clause.clone(),
236 };
237 }
238 }
239
240 let uses_vector_search = body
241 .order_by
242 .first()
243 .is_some_and(|order| matches!(order.direction, SortDirection::CosineDistance));
244
245 if let Some(order) = body.order_by.first()
246 && matches!(order.direction, SortDirection::CosineDistance)
247 {
248 let k = body.limit.ok_or(Error::UnboundedVectorSearch)?;
249 let vector_table = vector_base_table(¤t)?.ok_or_else(|| {
250 Error::PlanError("unable to resolve physical vector source table".to_string())
251 })?;
252 current = PhysicalPlan::VectorSearch {
253 table: vector_table,
254 column: "embedding".to_string(),
255 query_expr: order.expr.clone(),
256 k,
257 candidates: Some(Box::new(current)),
258 };
259 }
260
261 if !body.order_by.is_empty() && !uses_vector_search {
262 current = PhysicalPlan::Sort {
263 input: Box::new(current),
264 keys: body
265 .order_by
266 .iter()
267 .map(|item| SortKey {
268 expr: item.expr.clone(),
269 direction: item.direction,
270 })
271 .collect(),
272 };
273 }
274
275 let is_select_star = matches!(
276 body.columns.as_slice(),
277 [contextdb_parser::ast::SelectColumn {
278 expr: Expr::Column(contextdb_parser::ast::ColumnRef { table: None, column }),
279 alias: None
280 }] if column == "*"
281 );
282 if !is_select_star {
283 current = PhysicalPlan::Project {
284 input: Box::new(current),
285 columns: body
286 .columns
287 .iter()
288 .map(|column| ProjectColumn {
289 expr: column.expr.clone(),
290 alias: column.alias.clone(),
291 })
292 .collect(),
293 };
294 }
295
296 if body.distinct {
297 current = PhysicalPlan::Distinct {
298 input: Box::new(current),
299 };
300 }
301
302 if let Some(limit) = body.limit
303 && !uses_vector_search
304 {
305 current = PhysicalPlan::Limit {
306 input: Box::new(current),
307 count: limit,
308 };
309 }
310
311 Ok(current)
312}
313
314fn vector_base_table(plan: &PhysicalPlan) -> Result<Option<String>> {
315 match plan {
316 PhysicalPlan::Scan { table, .. } | PhysicalPlan::IndexScan { table, .. } => {
317 Ok(Some(table.clone()))
318 }
319 PhysicalPlan::Filter { input, .. }
320 | PhysicalPlan::Project { input, .. }
321 | PhysicalPlan::Distinct { input }
322 | PhysicalPlan::Sort { input, .. }
323 | PhysicalPlan::Limit { input, .. }
324 | PhysicalPlan::MaterializeCte { input, .. } => vector_base_table(input),
325 PhysicalPlan::Join { left, right, .. } => {
326 let left_table = vector_base_table(left)?;
327 let right_table = vector_base_table(right)?;
328 match (left_table, right_table) {
329 (Some(left), Some(right)) if left == right => Ok(Some(left)),
330 (Some(_), Some(_)) => Err(Error::PlanError(
331 "ambiguous physical vector source table in join".to_string(),
332 )),
333 (Some(table), None) | (None, Some(table)) => Ok(Some(table)),
334 (None, None) => Ok(None),
335 }
336 }
337 PhysicalPlan::Pipeline(plans) => {
338 for plan in plans.iter().rev() {
339 if let Some(table) = vector_base_table(plan)? {
340 return Ok(Some(table));
341 }
342 }
343 Ok(None)
344 }
345 PhysicalPlan::GraphBfs { .. }
346 | PhysicalPlan::CteRef { .. }
347 | PhysicalPlan::Union { .. } => Ok(None),
348 _ => Ok(None),
349 }
350}
351
352fn graph_plan_from_from_item(
353 from_item: &FromItem,
354 cte_env: &HashMap<String, PhysicalPlan>,
355) -> Result<PhysicalPlan> {
356 match from_item {
357 FromItem::GraphTable {
358 match_clause,
359 columns,
360 ..
361 } => {
362 let bfs = graph_bfs_from_match(match_clause, cte_env)?;
363 if columns.is_empty() {
364 Ok(bfs)
365 } else {
366 Ok(PhysicalPlan::Project {
367 input: Box::new(bfs),
368 columns: columns
369 .iter()
370 .map(|c| ProjectColumn {
371 expr: c.expr.clone(),
372 alias: Some(c.alias.clone()),
373 })
374 .collect(),
375 })
376 }
377 }
378 FromItem::Table { name, .. } => Ok(PhysicalPlan::Scan {
379 table: name.clone(),
380 alias: None,
381 filter: None,
382 }),
383 }
384}
385
386fn graph_bfs_from_match(
387 match_clause: &contextdb_parser::ast::MatchClause,
388 cte_env: &HashMap<String, PhysicalPlan>,
389) -> Result<PhysicalPlan> {
390 let steps = match_clause
391 .pattern
392 .edges
393 .iter()
394 .map(|step| {
395 let max_depth = if step.max_hops == 0 {
396 DEFAULT_MATCH_DEPTH
397 } else {
398 step.max_hops
399 };
400 if max_depth > ENGINE_MAX_BFS_DEPTH {
401 return Err(Error::BfsDepthExceeded(max_depth));
402 }
403
404 Ok(GraphStepPlan {
405 edge_types: step.edge_type.clone().map(|t| vec![t]).unwrap_or_default(),
406 direction: match step.direction {
407 contextdb_parser::ast::EdgeDirection::Outgoing => Direction::Outgoing,
408 contextdb_parser::ast::EdgeDirection::Incoming => Direction::Incoming,
409 contextdb_parser::ast::EdgeDirection::Both => Direction::Both,
410 },
411 min_depth: step.min_hops.max(1),
412 max_depth,
413 target_alias: step.target.alias.clone(),
414 })
415 })
416 .collect::<Result<Vec<_>>>()?;
417 if steps.is_empty() {
418 return Err(Error::PlanError(
419 "MATCH must include at least one edge".into(),
420 ));
421 }
422
423 Ok(PhysicalPlan::GraphBfs {
424 start_alias: match_clause.pattern.start.alias.clone(),
425 start_expr: extract_graph_start_expr(match_clause)?,
426 start_candidates: extract_graph_start_candidates(match_clause, cte_env)?,
427 steps,
428 filter: match_clause.where_clause.clone(),
429 })
430}
431
432fn extract_graph_start_candidates(
433 match_clause: &MatchClause,
434 cte_env: &HashMap<String, PhysicalPlan>,
435) -> Result<Option<Box<PhysicalPlan>>> {
436 let Some(where_clause) = &match_clause.where_clause else {
437 return Ok(None);
438 };
439 find_graph_start_candidates(where_clause, &match_clause.pattern.start.alias, cte_env)
440}
441
442fn find_graph_start_candidates(
443 expr: &Expr,
444 start_alias: &str,
445 cte_env: &HashMap<String, PhysicalPlan>,
446) -> Result<Option<Box<PhysicalPlan>>> {
447 match expr {
448 Expr::InSubquery { expr, subquery, .. } if is_graph_start_id_ref(expr, start_alias) => {
449 Ok(Some(Box::new(plan_select_body(subquery, cte_env)?)))
450 }
451 Expr::BinaryOp { left, right, .. } => {
452 if let Some(plan) = find_graph_start_candidates(left, start_alias, cte_env)? {
453 return Ok(Some(plan));
454 }
455 find_graph_start_candidates(right, start_alias, cte_env)
456 }
457 Expr::UnaryOp { operand, .. } => find_graph_start_candidates(operand, start_alias, cte_env),
458 _ => Ok(None),
459 }
460}
461
462fn extract_graph_start_expr(match_clause: &MatchClause) -> Result<Expr> {
463 let start_alias = &match_clause.pattern.start.alias;
464 if let Some(where_clause) = &match_clause.where_clause
465 && let Some(expr) = find_graph_start_expr(where_clause, start_alias)
466 {
467 return Ok(expr);
468 }
469
470 Ok(Expr::Column(contextdb_parser::ast::ColumnRef {
471 table: None,
472 column: start_alias.clone(),
473 }))
474}
475
476fn find_graph_start_expr(expr: &Expr, start_alias: &str) -> Option<Expr> {
477 match expr {
478 Expr::BinaryOp {
479 left,
480 op: BinOp::Eq,
481 right,
482 } => {
483 if is_graph_start_id_ref(left, start_alias) {
484 Some((**right).clone())
485 } else if is_graph_start_id_ref(right, start_alias) {
486 Some((**left).clone())
487 } else {
488 None
489 }
490 }
491 Expr::BinaryOp { left, right, .. } => find_graph_start_expr(left, start_alias)
492 .or_else(|| find_graph_start_expr(right, start_alias)),
493 Expr::UnaryOp { operand, .. } => find_graph_start_expr(operand, start_alias),
494 _ => None,
495 }
496}
497
498fn is_graph_start_id_ref(expr: &Expr, start_alias: &str) -> bool {
499 matches!(
500 expr,
501 Expr::Column(contextdb_parser::ast::ColumnRef {
502 table: Some(table),
503 column
504 }) if table == start_alias && column == "id"
505 )
506}