Skip to main content

fathomdb_query/
plan.rs

1use std::fmt::Write;
2
3use crate::{Predicate, QueryAst, QueryStep, TraverseDirection};
4
/// Planner choice returned by [`choose_driving_table`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrivingTable {
    /// Chosen when a `LogicalIdEq` or `SourceRefEq` filter step is present, and also
    /// when the query has neither a vector search nor a text search step.
    Nodes,
    /// Chosen when a `TextSearch` step is present and neither an id-style filter
    /// nor a `VectorSearch` step takes precedence.
    FtsNodes,
    /// Chosen when a `VectorSearch` step is present and no id-style filter overrides it.
    VecNodes,
}
11
/// Limits derived from a query AST by [`execution_hints`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionHints {
    /// Larger of the first `Traverse` step's `max_depth` and the greatest expansion
    /// `max_depth`; 0 when the query has neither.
    pub recursion_limit: usize,
    /// The query's `final_limit` when one is set, otherwise 1000.
    pub hard_limit: usize,
}
17
18pub fn choose_driving_table(ast: &QueryAst) -> DrivingTable {
19    let has_deterministic_id_filter = ast.steps.iter().any(|step| {
20        matches!(
21            step,
22            QueryStep::Filter(Predicate::LogicalIdEq(_) | Predicate::SourceRefEq(_))
23        )
24    });
25
26    if has_deterministic_id_filter {
27        DrivingTable::Nodes
28    } else if ast
29        .steps
30        .iter()
31        .any(|step| matches!(step, QueryStep::VectorSearch { .. }))
32    {
33        DrivingTable::VecNodes
34    } else if ast
35        .steps
36        .iter()
37        .any(|step| matches!(step, QueryStep::TextSearch { .. }))
38    {
39        DrivingTable::FtsNodes
40    } else {
41        DrivingTable::Nodes
42    }
43}
44
45pub fn execution_hints(ast: &QueryAst) -> ExecutionHints {
46    let step_limit = ast
47        .steps
48        .iter()
49        .find_map(|step| {
50            if let QueryStep::Traverse { max_depth, .. } = step {
51                Some(*max_depth)
52            } else {
53                None
54            }
55        })
56        .unwrap_or(0);
57    let expansion_limit = ast
58        .expansions
59        .iter()
60        .map(|expansion| expansion.max_depth)
61        .max()
62        .unwrap_or(0);
63    let recursion_limit = step_limit.max(expansion_limit);
64
65    ExecutionHints {
66        recursion_limit,
67        // FIX(review): was .max(1000) — always produced >= 1000, ignoring user's final_limit.
68        // Options considered: (A) use final_limit directly with default, (B) .min(MAX) ceiling,
69        // (C) decouple from final_limit. Chose (A): the CTE LIMIT should honor the user's
70        // requested limit; the depth bound (compile.rs:177) already constrains recursion.
71        hard_limit: ast.final_limit.unwrap_or(1000),
72    }
73}
74
75pub fn shape_signature(ast: &QueryAst) -> String {
76    let mut signature = String::new();
77    let _ = write!(&mut signature, "Root({})", ast.root_kind);
78
79    for step in &ast.steps {
80        match step {
81            QueryStep::VectorSearch { limit, .. } => {
82                let _ = write!(&mut signature, "-Vector(limit={limit})");
83            }
84            QueryStep::TextSearch { limit, .. } => {
85                let _ = write!(&mut signature, "-Text(limit={limit})");
86            }
87            QueryStep::Traverse {
88                direction,
89                label,
90                max_depth,
91            } => {
92                let dir = match direction {
93                    TraverseDirection::In => "in",
94                    TraverseDirection::Out => "out",
95                };
96                let _ = write!(
97                    &mut signature,
98                    "-Traverse(direction={dir},label={label},depth={max_depth})"
99                );
100            }
101            QueryStep::Filter(predicate) => match predicate {
102                Predicate::LogicalIdEq(_) => signature.push_str("-Filter(logical_id_eq)"),
103                Predicate::KindEq(_) => signature.push_str("-Filter(kind_eq)"),
104                Predicate::JsonPathEq { path, .. } => {
105                    let _ = write!(&mut signature, "-Filter(json_eq:{path})");
106                }
107                Predicate::JsonPathCompare { path, op, .. } => {
108                    let op = match op {
109                        crate::ComparisonOp::Gt => "gt",
110                        crate::ComparisonOp::Gte => "gte",
111                        crate::ComparisonOp::Lt => "lt",
112                        crate::ComparisonOp::Lte => "lte",
113                    };
114                    let _ = write!(&mut signature, "-Filter(json_cmp:{path}:{op})");
115                }
116                Predicate::SourceRefEq(_) => signature.push_str("-Filter(source_ref_eq)"),
117                Predicate::ContentRefNotNull => {
118                    signature.push_str("-Filter(content_ref_not_null)");
119                }
120                Predicate::ContentRefEq(_) => signature.push_str("-Filter(content_ref_eq)"),
121            },
122        }
123    }
124
125    for expansion in &ast.expansions {
126        let dir = match expansion.direction {
127            TraverseDirection::In => "in",
128            TraverseDirection::Out => "out",
129        };
130        let _ = write!(
131            &mut signature,
132            "-Expand(slot={},direction={dir},label={},depth={})",
133            expansion.slot, expansion.label, expansion.max_depth
134        );
135    }
136
137    if let Some(limit) = ast.final_limit {
138        let _ = write!(&mut signature, "-Limit({limit})");
139    }
140
141    signature
142}
143
#[cfg(test)]
mod tests {
    use super::{choose_driving_table, execution_hints};
    use crate::{DrivingTable, QueryBuilder, TraverseDirection};

    /// A logical-id filter must force the `Nodes` driver even when the query also
    /// contains a vector search step.
    #[test]
    fn deterministic_filter_overrides_vector_driver() {
        let query = QueryBuilder::nodes("Meeting")
            .vector_search("budget", 5)
            .filter_logical_id_eq("meeting-123")
            .into_ast();

        let driver = choose_driving_table(&query);

        assert_eq!(driver, DrivingTable::Nodes);
    }

    /// An explicit limit smaller than the default must be passed through unchanged.
    #[test]
    fn hard_limit_honors_user_specified_limit_below_default() {
        let query = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .limit(10)
            .into_ast();

        let hard_limit = execution_hints(&query).hard_limit;

        assert_eq!(hard_limit, 10, "hard_limit must honor user's final_limit");
    }

    /// Without an explicit limit the default of 1000 applies.
    #[test]
    fn hard_limit_defaults_to_1000_when_no_limit_set() {
        let query = QueryBuilder::nodes("Meeting")
            .traverse(TraverseDirection::Out, "HAS_TASK", 3)
            .into_ast();

        let hard_limit = execution_hints(&query).hard_limit;

        assert_eq!(hard_limit, 1000, "hard_limit defaults to 1000");
    }
}