1use std::fmt::Write;
2
3use crate::{Predicate, QueryAst, QueryStep, TraverseDirection};
4
/// Table chosen by [`choose_driving_table`] to drive a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrivingTable {
    /// The plain nodes table.
    Nodes,
    /// The full-text-search nodes table (selected for `TextSearch` steps).
    FtsNodes,
    /// The vector nodes table (selected for `VectorSearch` steps).
    VecNodes,
}
11
/// Runtime bounds computed by [`execution_hints`] for a query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionHints {
    /// Deepest traversal/expansion depth requested by the query (0 when none).
    pub recursion_limit: usize,
    /// The query's final limit, or a default when the query sets none.
    pub hard_limit: usize,
}
17
18pub fn choose_driving_table(ast: &QueryAst) -> DrivingTable {
19 let has_deterministic_id_filter = ast.steps.iter().any(|step| {
20 matches!(
21 step,
22 QueryStep::Filter(Predicate::LogicalIdEq(_) | Predicate::SourceRefEq(_))
23 )
24 });
25
26 if has_deterministic_id_filter {
27 DrivingTable::Nodes
28 } else if ast
29 .steps
30 .iter()
31 .any(|step| matches!(step, QueryStep::VectorSearch { .. }))
32 {
33 DrivingTable::VecNodes
34 } else if ast
35 .steps
36 .iter()
37 .any(|step| matches!(step, QueryStep::TextSearch { .. }))
38 {
39 DrivingTable::FtsNodes
40 } else {
41 DrivingTable::Nodes
42 }
43}
44
45pub fn execution_hints(ast: &QueryAst) -> ExecutionHints {
46 let step_limit = ast
47 .steps
48 .iter()
49 .find_map(|step| {
50 if let QueryStep::Traverse { max_depth, .. } = step {
51 Some(*max_depth)
52 } else {
53 None
54 }
55 })
56 .unwrap_or(0);
57 let expansion_limit = ast
58 .expansions
59 .iter()
60 .map(|expansion| expansion.max_depth)
61 .max()
62 .unwrap_or(0);
63 let recursion_limit = step_limit.max(expansion_limit);
64
65 ExecutionHints {
66 recursion_limit,
67 hard_limit: ast.final_limit.unwrap_or(1000),
72 }
73}
74
75pub fn shape_signature(ast: &QueryAst) -> String {
76 let mut signature = String::new();
77 let _ = write!(&mut signature, "Root({})", ast.root_kind);
78
79 for step in &ast.steps {
80 match step {
81 QueryStep::VectorSearch { limit, .. } => {
82 let _ = write!(&mut signature, "-Vector(limit={limit})");
83 }
84 QueryStep::TextSearch { limit, .. } => {
85 let _ = write!(&mut signature, "-Text(limit={limit})");
86 }
87 QueryStep::Traverse {
88 direction,
89 label,
90 max_depth,
91 } => {
92 let dir = match direction {
93 TraverseDirection::In => "in",
94 TraverseDirection::Out => "out",
95 };
96 let _ = write!(
97 &mut signature,
98 "-Traverse(direction={dir},label={label},depth={max_depth})"
99 );
100 }
101 QueryStep::Filter(predicate) => match predicate {
102 Predicate::LogicalIdEq(_) => signature.push_str("-Filter(logical_id_eq)"),
103 Predicate::KindEq(_) => signature.push_str("-Filter(kind_eq)"),
104 Predicate::JsonPathEq { path, .. } => {
105 let _ = write!(&mut signature, "-Filter(json_eq:{path})");
106 }
107 Predicate::JsonPathCompare { path, op, .. } => {
108 let op = match op {
109 crate::ComparisonOp::Gt => "gt",
110 crate::ComparisonOp::Gte => "gte",
111 crate::ComparisonOp::Lt => "lt",
112 crate::ComparisonOp::Lte => "lte",
113 };
114 let _ = write!(&mut signature, "-Filter(json_cmp:{path}:{op})");
115 }
116 Predicate::SourceRefEq(_) => signature.push_str("-Filter(source_ref_eq)"),
117 Predicate::ContentRefNotNull => {
118 signature.push_str("-Filter(content_ref_not_null)");
119 }
120 Predicate::ContentRefEq(_) => signature.push_str("-Filter(content_ref_eq)"),
121 },
122 }
123 }
124
125 for expansion in &ast.expansions {
126 let dir = match expansion.direction {
127 TraverseDirection::In => "in",
128 TraverseDirection::Out => "out",
129 };
130 let _ = write!(
131 &mut signature,
132 "-Expand(slot={},direction={dir},label={},depth={})",
133 expansion.slot, expansion.label, expansion.max_depth
134 );
135 }
136
137 if let Some(limit) = ast.final_limit {
138 let _ = write!(&mut signature, "-Limit({limit})");
139 }
140
141 signature
142}
143
#[cfg(test)]
mod tests {
    use super::{choose_driving_table, execution_hints};
    use crate::{DrivingTable, QueryBuilder, TraverseDirection};

    /// An exact logical-id filter must win over a vector search when picking the driver.
    #[test]
    fn deterministic_filter_overrides_vector_driver() {
        let driver = choose_driving_table(
            &QueryBuilder::nodes("Meeting")
                .vector_search("budget", 5)
                .filter_logical_id_eq("meeting-123")
                .into_ast(),
        );

        assert_eq!(driver, DrivingTable::Nodes);
    }

    /// A user-supplied limit smaller than the default is passed through unchanged.
    #[test]
    fn hard_limit_honors_user_specified_limit_below_default() {
        let limited = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .limit(10)
            .into_ast();

        assert_eq!(
            execution_hints(&limited).hard_limit,
            10,
            "hard_limit must honor user's final_limit"
        );
    }

    /// Without an explicit limit the hard limit falls back to 1000.
    #[test]
    fn hard_limit_defaults_to_1000_when_no_limit_set() {
        let unlimited = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .into_ast();

        assert_eq!(
            execution_hints(&unlimited).hard_limit,
            1000,
            "hard_limit defaults to 1000"
        );
    }
}