1use std::fmt::Write;
2
3use crate::{Predicate, QueryAst, QueryStep, TraverseDirection};
4
/// The table a query is driven from, as selected by [`choose_driving_table`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrivingTable {
    /// Selected when the query filters on `LogicalIdEq` / `SourceRefEq`,
    /// or when it contains no vector or text search step.
    Nodes,
    /// Selected when the query has a `TextSearch` step, and no vector search
    /// or deterministic id filter takes precedence.
    FtsNodes,
    /// Selected when the query has a `VectorSearch` step and no deterministic
    /// id filter.
    VecNodes,
}
11
/// Runtime limits derived from a query AST by [`execution_hints`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionHints {
    /// Depth bound for recursive work, computed from the `max_depth` of the
    /// query's `Traverse` steps and expansions (`0` when there are none).
    pub recursion_limit: usize,
    /// Row cap: the query's `final_limit`, or `1000` when it sets none.
    pub hard_limit: usize,
}
17
18pub fn choose_driving_table(ast: &QueryAst) -> DrivingTable {
19 let has_deterministic_id_filter = ast.steps.iter().any(|step| {
20 matches!(
21 step,
22 QueryStep::Filter(Predicate::LogicalIdEq(_) | Predicate::SourceRefEq(_))
23 )
24 });
25
26 if has_deterministic_id_filter {
27 DrivingTable::Nodes
28 } else if ast
29 .steps
30 .iter()
31 .any(|step| matches!(step, QueryStep::VectorSearch { .. }))
32 {
33 DrivingTable::VecNodes
34 } else if ast
35 .steps
36 .iter()
37 .any(|step| matches!(step, QueryStep::TextSearch { .. }))
38 {
39 DrivingTable::FtsNodes
40 } else {
41 DrivingTable::Nodes
42 }
43}
44
45pub fn execution_hints(ast: &QueryAst) -> ExecutionHints {
46 let step_limit = ast
47 .steps
48 .iter()
49 .find_map(|step| {
50 if let QueryStep::Traverse { max_depth, .. } = step {
51 Some(*max_depth)
52 } else {
53 None
54 }
55 })
56 .unwrap_or(0);
57 let expansion_limit = ast
58 .expansions
59 .iter()
60 .map(|expansion| expansion.max_depth)
61 .max()
62 .unwrap_or(0);
63 let recursion_limit = step_limit.max(expansion_limit);
64
65 ExecutionHints {
66 recursion_limit,
67 hard_limit: ast.final_limit.unwrap_or(1000),
72 }
73}
74
75pub fn shape_signature(ast: &QueryAst) -> String {
76 let mut signature = String::new();
77 let _ = write!(&mut signature, "Root({})", ast.root_kind);
78
79 for step in &ast.steps {
80 match step {
81 QueryStep::Search { limit, .. } => {
82 let _ = write!(&mut signature, "-Search(limit={limit})");
83 }
84 QueryStep::VectorSearch { limit, .. } => {
85 let _ = write!(&mut signature, "-Vector(limit={limit})");
86 }
87 QueryStep::TextSearch { limit, .. } => {
88 let _ = write!(&mut signature, "-Text(limit={limit})");
89 }
90 QueryStep::Traverse {
91 direction,
92 label,
93 max_depth,
94 filter: _,
95 } => {
96 let dir = match direction {
97 TraverseDirection::In => "in",
98 TraverseDirection::Out => "out",
99 };
100 let _ = write!(
101 &mut signature,
102 "-Traverse(direction={dir},label={label},depth={max_depth})"
103 );
104 }
105 QueryStep::Filter(predicate) => match predicate {
106 Predicate::LogicalIdEq(_) => signature.push_str("-Filter(logical_id_eq)"),
107 Predicate::KindEq(_) => signature.push_str("-Filter(kind_eq)"),
108 Predicate::JsonPathEq { path, .. } => {
109 let _ = write!(&mut signature, "-Filter(json_eq:{path})");
110 }
111 Predicate::JsonPathCompare { path, op, .. } => {
112 let op = match op {
113 crate::ComparisonOp::Gt => "gt",
114 crate::ComparisonOp::Gte => "gte",
115 crate::ComparisonOp::Lt => "lt",
116 crate::ComparisonOp::Lte => "lte",
117 };
118 let _ = write!(&mut signature, "-Filter(json_cmp:{path}:{op})");
119 }
120 Predicate::SourceRefEq(_) => signature.push_str("-Filter(source_ref_eq)"),
121 Predicate::ContentRefNotNull => {
122 signature.push_str("-Filter(content_ref_not_null)");
123 }
124 Predicate::ContentRefEq(_) => signature.push_str("-Filter(content_ref_eq)"),
125 Predicate::JsonPathFusedEq { path, .. } => {
126 let _ = write!(&mut signature, "-Filter(json_fused_eq:{path})");
127 }
128 Predicate::JsonPathFusedTimestampCmp { path, op, .. } => {
129 let op = match op {
130 crate::ComparisonOp::Gt => "gt",
131 crate::ComparisonOp::Gte => "gte",
132 crate::ComparisonOp::Lt => "lt",
133 crate::ComparisonOp::Lte => "lte",
134 };
135 let _ = write!(&mut signature, "-Filter(json_fused_ts_cmp:{path}:{op})");
136 }
137 },
138 }
139 }
140
141 for expansion in &ast.expansions {
142 let dir = match expansion.direction {
143 TraverseDirection::In => "in",
144 TraverseDirection::Out => "out",
145 };
146 let _ = write!(
147 &mut signature,
148 "-Expand(slot={},direction={dir},label={},depth={})",
149 expansion.slot, expansion.label, expansion.max_depth
150 );
151 }
152
153 if let Some(limit) = ast.final_limit {
154 let _ = write!(&mut signature, "-Limit({limit})");
155 }
156
157 signature
158}
159
#[cfg(test)]
mod tests {
    use super::{choose_driving_table, execution_hints};
    use crate::{DrivingTable, QueryBuilder, TraverseDirection};

    /// A `LogicalIdEq` filter takes precedence over a vector search when the
    /// driving table is chosen.
    #[test]
    fn deterministic_filter_overrides_vector_driver() {
        let ast = QueryBuilder::nodes("Meeting")
            .vector_search("budget", 5)
            .filter_logical_id_eq("meeting-123")
            .into_ast();

        let driver = choose_driving_table(&ast);
        assert_eq!(driver, DrivingTable::Nodes);
    }

    /// An explicit `.limit(10)` is carried through to `hard_limit`.
    #[test]
    fn hard_limit_honors_user_specified_limit_below_default() {
        let hard_limit = execution_hints(
            &QueryBuilder::nodes("Meeting")
                .traverse(TraverseDirection::Out, "HAS_TASK", 3)
                .limit(10)
                .into_ast(),
        )
        .hard_limit;

        assert_eq!(
            hard_limit, 10,
            "hard_limit must honor user's final_limit"
        );
    }

    /// Without a `.limit(..)` call, `hard_limit` falls back to 1000.
    #[test]
    fn hard_limit_defaults_to_1000_when_no_limit_set() {
        let hard_limit = execution_hints(
            &QueryBuilder::nodes("Meeting")
                .traverse(TraverseDirection::Out, "HAS_TASK", 3)
                .into_ast(),
        )
        .hard_limit;

        assert_eq!(hard_limit, 1000, "hard_limit defaults to 1000");
    }
}