1use std::fmt::Write;
2
3use crate::{Predicate, QueryAst, QueryStep, TraverseDirection};
4
/// Table a query is driven from, as selected by [`choose_driving_table`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrivingTable {
    /// Selected when a step filters by logical id or source ref, or when the
    /// query contains neither a vector-search nor a text-search step.
    Nodes,
    /// Selected for a query with a text-search step, provided it has no
    /// logical-id/source-ref filter and no vector-search step.
    FtsNodes,
    /// Selected for a query with a vector-search step, provided it has no
    /// logical-id/source-ref filter.
    VecNodes,
}
11
/// Limits derived from a query by [`execution_hints`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionHints {
    /// Depth budget derived from the query's traversal and expansion depths
    /// (0 when the query has neither).
    pub recursion_limit: usize,
    /// The query's explicit final limit, or 1000 when none was set.
    pub hard_limit: usize,
}
17
18pub fn choose_driving_table(ast: &QueryAst) -> DrivingTable {
19 let has_deterministic_id_filter = ast.steps.iter().any(|step| {
20 matches!(
21 step,
22 QueryStep::Filter(Predicate::LogicalIdEq(_) | Predicate::SourceRefEq(_))
23 )
24 });
25
26 if has_deterministic_id_filter {
27 DrivingTable::Nodes
28 } else if ast
29 .steps
30 .iter()
31 .any(|step| matches!(step, QueryStep::VectorSearch { .. }))
32 {
33 DrivingTable::VecNodes
34 } else if ast
35 .steps
36 .iter()
37 .any(|step| matches!(step, QueryStep::TextSearch { .. }))
38 {
39 DrivingTable::FtsNodes
40 } else {
41 DrivingTable::Nodes
42 }
43}
44
45pub fn execution_hints(ast: &QueryAst) -> ExecutionHints {
46 let step_limit = ast
47 .steps
48 .iter()
49 .find_map(|step| {
50 if let QueryStep::Traverse { max_depth, .. } = step {
51 Some(*max_depth)
52 } else {
53 None
54 }
55 })
56 .unwrap_or(0);
57 let expansion_limit = ast
58 .expansions
59 .iter()
60 .map(|expansion| expansion.max_depth)
61 .max()
62 .unwrap_or(0);
63 let recursion_limit = step_limit.max(expansion_limit);
64
65 ExecutionHints {
66 recursion_limit,
67 hard_limit: ast.final_limit.unwrap_or(1000),
72 }
73}
74
75pub fn shape_signature(ast: &QueryAst) -> String {
76 let mut signature = String::new();
77 let _ = write!(&mut signature, "Root({})", ast.root_kind);
78
79 for step in &ast.steps {
80 match step {
81 QueryStep::Search { limit, .. } => {
82 let _ = write!(&mut signature, "-Search(limit={limit})");
83 }
84 QueryStep::VectorSearch { limit, .. } => {
85 let _ = write!(&mut signature, "-Vector(limit={limit})");
86 }
87 QueryStep::TextSearch { limit, .. } => {
88 let _ = write!(&mut signature, "-Text(limit={limit})");
89 }
90 QueryStep::Traverse {
91 direction,
92 label,
93 max_depth,
94 } => {
95 let dir = match direction {
96 TraverseDirection::In => "in",
97 TraverseDirection::Out => "out",
98 };
99 let _ = write!(
100 &mut signature,
101 "-Traverse(direction={dir},label={label},depth={max_depth})"
102 );
103 }
104 QueryStep::Filter(predicate) => match predicate {
105 Predicate::LogicalIdEq(_) => signature.push_str("-Filter(logical_id_eq)"),
106 Predicate::KindEq(_) => signature.push_str("-Filter(kind_eq)"),
107 Predicate::JsonPathEq { path, .. } => {
108 let _ = write!(&mut signature, "-Filter(json_eq:{path})");
109 }
110 Predicate::JsonPathCompare { path, op, .. } => {
111 let op = match op {
112 crate::ComparisonOp::Gt => "gt",
113 crate::ComparisonOp::Gte => "gte",
114 crate::ComparisonOp::Lt => "lt",
115 crate::ComparisonOp::Lte => "lte",
116 };
117 let _ = write!(&mut signature, "-Filter(json_cmp:{path}:{op})");
118 }
119 Predicate::SourceRefEq(_) => signature.push_str("-Filter(source_ref_eq)"),
120 Predicate::ContentRefNotNull => {
121 signature.push_str("-Filter(content_ref_not_null)");
122 }
123 Predicate::ContentRefEq(_) => signature.push_str("-Filter(content_ref_eq)"),
124 },
125 }
126 }
127
128 for expansion in &ast.expansions {
129 let dir = match expansion.direction {
130 TraverseDirection::In => "in",
131 TraverseDirection::Out => "out",
132 };
133 let _ = write!(
134 &mut signature,
135 "-Expand(slot={},direction={dir},label={},depth={})",
136 expansion.slot, expansion.label, expansion.max_depth
137 );
138 }
139
140 if let Some(limit) = ast.final_limit {
141 let _ = write!(&mut signature, "-Limit({limit})");
142 }
143
144 signature
145}
146
#[cfg(test)]
mod tests {
    use crate::{DrivingTable, QueryBuilder, TraverseDirection};

    use super::{choose_driving_table, execution_hints};

    #[test]
    fn deterministic_filter_overrides_vector_driver() {
        let ast = QueryBuilder::nodes("Meeting")
            .vector_search("budget", 5)
            .filter_logical_id_eq("meeting-123")
            .into_ast();

        assert_eq!(choose_driving_table(&ast), DrivingTable::Nodes);
    }

    // Without an id filter, a vector-search step selects the vector table.
    #[test]
    fn vector_search_drives_from_vec_nodes() {
        let ast = QueryBuilder::nodes("Meeting")
            .vector_search("budget", 5)
            .into_ast();

        assert_eq!(choose_driving_table(&ast), DrivingTable::VecNodes);
    }

    // With no search steps at all, the query falls back to the nodes table.
    #[test]
    fn plain_node_query_drives_from_nodes() {
        let ast = QueryBuilder::nodes("Meeting").into_ast();

        assert_eq!(choose_driving_table(&ast), DrivingTable::Nodes);
    }

    #[test]
    fn recursion_limit_follows_traverse_depth() {
        let ast = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .into_ast();

        let hints = execution_hints(&ast);
        assert_eq!(
            hints.recursion_limit, 3,
            "recursion_limit must cover the traversal depth"
        );
    }

    #[test]
    fn hard_limit_honors_user_specified_limit_below_default() {
        let ast = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .limit(10)
            .into_ast();

        let hints = execution_hints(&ast);
        assert_eq!(
            hints.hard_limit, 10,
            "hard_limit must honor user's final_limit"
        );
    }

    #[test]
    fn hard_limit_defaults_to_1000_when_no_limit_set() {
        let ast = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .into_ast();

        let hints = execution_hints(&ast);
        assert_eq!(hints.hard_limit, 1000, "hard_limit defaults to 1000");
    }
}