Skip to main content

nodedb_sql/ddl_ast/graph_parse/
mod.rs

1//! Typed parsing for the graph DSL (`GRAPH ...`, `MATCH ...`).
2//!
3//! The handler layer historically parsed graph statements with
4//! `upper.find("KEYWORD")` substring matching, which collapsed when
5//! a node id, label, or property value shadowed a DSL keyword. This
6//! module is the structural replacement: a quote- and brace-aware
7//! tokeniser feeds a variant-building parser that produces a typed
8//! [`NodedbStatement`]. Every graph DSL command flows through here
9//! before reaching a pgwire handler, so the handlers never touch
10//! raw SQL again.
11//!
12//! Values collected here are intentionally unvalidated — numeric
13//! bounds, absent-but-required fields, and engine-level caps are
14//! enforced at the pgwire boundary where the error response is
15//! formed. That keeps this module free of `pgwire` dependencies
16//! and out of the `nodedb` → `nodedb-sql` dependency edge.
17
18pub mod fusion_params;
19mod helpers;
20mod tokenizer;
21mod variants;
22
23pub use fusion_params::{
24    FusionKeywords, FusionParams, RAG_FUSION_KEYWORDS, SEARCH_FUSION_KEYWORDS,
25    parse_search_using_fusion,
26};
27
28use super::statement::NodedbStatement;
29
30/// Entry point: returns `Some` when `sql` starts with a graph DSL
31/// keyword (`GRAPH ...`, `MATCH ...`, `OPTIONAL MATCH ...`), else
32/// `None` so the main DDL parser can continue trying other cases.
33pub fn try_parse(sql: &str) -> Option<NodedbStatement> {
34    let trimmed = sql.trim();
35    let upper = trimmed.to_ascii_uppercase();
36
37    if upper.starts_with("MATCH ") || upper.starts_with("OPTIONAL MATCH ") {
38        return Some(NodedbStatement::MatchQuery {
39            raw_sql: trimmed.to_string(),
40        });
41    }
42
43    if !upper.starts_with("GRAPH ") {
44        return None;
45    }
46
47    let toks = tokenizer::tokenize(trimmed);
48
49    if upper.starts_with("GRAPH INSERT EDGE ") {
50        return variants::parse_insert_edge(&toks);
51    }
52    if upper.starts_with("GRAPH DELETE EDGE ") {
53        return variants::parse_delete_edge(&toks);
54    }
55    if upper.starts_with("GRAPH LABEL ") {
56        return variants::parse_set_labels(&toks, false);
57    }
58    if upper.starts_with("GRAPH UNLABEL ") {
59        return variants::parse_set_labels(&toks, true);
60    }
61    if upper.starts_with("GRAPH TRAVERSE ") {
62        return variants::parse_traverse(&toks);
63    }
64    if upper.starts_with("GRAPH NEIGHBORS ") {
65        return variants::parse_neighbors(&toks);
66    }
67    if upper.starts_with("GRAPH PATH ") {
68        return variants::parse_path(&toks);
69    }
70    if upper.starts_with("GRAPH ALGO ") {
71        return variants::parse_algo(&toks);
72    }
73    if upper.starts_with("GRAPH RAG FUSION ") {
74        return variants::parse_rag_fusion(&toks, trimmed);
75    }
76
77    None
78}
79
#[cfg(test)]
mod tests {
    // Unit tests for `try_parse` dispatch and the per-variant parsers it
    // delegates to.
    use super::*;
    use crate::ddl_ast::statement::{GraphDirection, GraphProperties};

    #[test]
    fn parse_graph_insert_edge_keyword_shaped_ids() {
        let stmt =
            try_parse("GRAPH INSERT EDGE IN 'myedges' FROM 'TO' TO 'FROM' TYPE 'LABEL'").unwrap();
        match stmt {
            NodedbStatement::GraphInsertEdge {
                collection,
                src,
                dst,
                label,
                properties,
            } => {
                assert_eq!(collection, "myedges");
                assert_eq!(src, "TO");
                assert_eq!(dst, "FROM");
                assert_eq!(label, "LABEL");
                assert_eq!(properties, GraphProperties::None);
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_delete_edge_with_collection() {
        let stmt = try_parse("GRAPH DELETE EDGE IN 'myedges' FROM 'a' TO 'b' TYPE 'l'").unwrap();
        match stmt {
            NodedbStatement::GraphDeleteEdge {
                collection,
                src,
                dst,
                label,
            } => {
                assert_eq!(collection, "myedges");
                assert_eq!(src, "a");
                assert_eq!(dst, "b");
                assert_eq!(label, "l");
            }
            other => panic!("expected GraphDeleteEdge, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_insert_edge_missing_collection_returns_none() {
        let result = try_parse("GRAPH INSERT EDGE FROM 'a' TO 'b' TYPE 'l'");
        assert!(
            result.is_none(),
            "missing IN <collection> must not produce a statement"
        );
    }

    #[test]
    fn parse_graph_insert_edge_with_object_properties() {
        let stmt = try_parse(
            "GRAPH INSERT EDGE IN 'edges' FROM 'a' TO 'b' TYPE 'l' PROPERTIES { note: '} DEPTH 999' }",
        )
        .unwrap();
        match stmt {
            NodedbStatement::GraphInsertEdge {
                collection,
                properties,
                ..
            } => {
                assert_eq!(collection, "edges");
                match properties {
                    GraphProperties::Object(s) => assert!(s.contains("} DEPTH 999")),
                    other => panic!("expected Object properties, got {other:?}"),
                }
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_traverse_keyword_substring_id() {
        let stmt =
            try_parse("GRAPH TRAVERSE FROM 'node_with_DEPTH_in_name' DEPTH 2 LABEL 'l'").unwrap();
        match stmt {
            NodedbStatement::GraphTraverse { start, depth, .. } => {
                assert_eq!(start, "node_with_DEPTH_in_name");
                assert_eq!(depth, 2);
            }
            other => panic!("expected GraphTraverse, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_path() {
        let stmt = try_parse("GRAPH PATH FROM 'a' TO 'b' MAX_DEPTH 5 LABEL 'l'").unwrap();
        match stmt {
            NodedbStatement::GraphPath {
                src,
                dst,
                max_depth,
                edge_label,
            } => {
                assert_eq!(src, "a");
                assert_eq!(dst, "b");
                assert_eq!(max_depth, 5);
                assert_eq!(edge_label.as_deref(), Some("l"));
            }
            other => panic!("expected GraphPath, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_labels_list() {
        let stmt = try_parse("GRAPH LABEL 'alice' AS 'Person', 'User'").unwrap();
        match stmt {
            NodedbStatement::GraphSetLabels {
                node_id,
                labels,
                remove,
            } => {
                assert_eq!(node_id, "alice");
                assert_eq!(labels, vec!["Person".to_string(), "User".to_string()]);
                assert!(!remove);
            }
            other => panic!("expected GraphSetLabels, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_algo_pagerank() {
        let stmt = try_parse("GRAPH ALGO PAGERANK ON users ITERATIONS 5 DAMPING 0.85").unwrap();
        match stmt {
            NodedbStatement::GraphAlgo {
                algorithm,
                collection,
                damping,
                max_iterations,
                ..
            } => {
                assert_eq!(algorithm, "PAGERANK");
                assert_eq!(collection, "users");
                assert_eq!(damping, Some(0.85));
                assert_eq!(max_iterations, Some(5));
            }
            other => panic!("expected GraphAlgo, got {other:?}"),
        }
    }

    #[test]
    fn parse_match_query_captures_raw() {
        let stmt = try_parse("MATCH (x)-[:l]->(y) RETURN x, y").unwrap();
        match stmt {
            NodedbStatement::MatchQuery { raw_sql } => {
                assert!(raw_sql.starts_with("MATCH"));
            }
            other => panic!("expected MatchQuery, got {other:?}"),
        }
    }

    #[test]
    fn non_graph_returns_none() {
        assert!(try_parse("SELECT * FROM users").is_none());
        assert!(try_parse("CREATE COLLECTION users").is_none());
    }

    // ── Dispatch tests ───────────────────────────────────────────────────

    #[test]
    fn parse_optional_match_captures_raw() {
        let stmt = try_parse("OPTIONAL MATCH (x)-[:l]->(y) RETURN x, y").unwrap();
        match stmt {
            NodedbStatement::MatchQuery { raw_sql } => {
                assert_eq!(raw_sql, "OPTIONAL MATCH (x)-[:l]->(y) RETURN x, y");
            }
            other => panic!("expected MatchQuery, got {other:?}"),
        }
    }

    #[test]
    fn match_keyword_is_case_insensitive_and_raw_sql_is_trimmed() {
        let stmt = try_parse("  match (x) return x \n").unwrap();
        match stmt {
            NodedbStatement::MatchQuery { raw_sql } => {
                // Original casing is preserved; only surrounding whitespace is dropped.
                assert_eq!(raw_sql, "match (x) return x");
            }
            other => panic!("expected MatchQuery, got {other:?}"),
        }
    }

    #[test]
    fn bare_keywords_without_body_return_none() {
        assert!(try_parse("MATCH").is_none());
        assert!(try_parse("OPTIONAL MATCH").is_none());
        assert!(try_parse("GRAPH").is_none());
        assert!(try_parse("   ").is_none());
    }

    #[test]
    fn keyword_prefix_of_longer_word_is_not_graph_dsl() {
        assert!(try_parse("MATCHES (x)").is_none());
        assert!(try_parse("GRAPHQL query { x }").is_none());
    }

    #[test]
    fn unknown_graph_subcommand_returns_none() {
        assert!(try_parse("GRAPH FROBNICATE 'x'").is_none());
    }

    // ── GraphRagFusion parser tests ──────────────────────────────────────

    #[test]
    fn parse_rag_fusion_full_syntax() {
        let stmt = try_parse(
            "GRAPH RAG FUSION ON entities \
             QUERY ARRAY[0.1, 0.2, 0.3] \
             VECTOR_TOP_K 50 \
             EXPANSION_DEPTH 2 \
             EDGE_LABEL 'related_to' \
             FINAL_TOP_K 10 \
             RRF_K (60.0, 35.0)",
        )
        .unwrap();
        match stmt {
            NodedbStatement::GraphRagFusion { collection, params } => {
                assert_eq!(collection, "entities");
                let v = params.query_vector.expect("QUERY ARRAY parsed");
                assert_eq!(v.len(), 3);
                assert!((v[0] - 0.1f32).abs() < 1e-5);
                assert_eq!(params.vector_top_k, Some(50));
                assert_eq!(params.expansion_depth, Some(2));
                assert_eq!(params.edge_label.as_deref(), Some("related_to"));
                assert_eq!(params.final_top_k, Some(10));
                let (k1, k2) = params.rrf_k.unwrap();
                assert!((k1 - 60.0).abs() < 1e-10);
                assert!((k2 - 35.0).abs() < 1e-10);
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_minimal_defaults_to_none() {
        let stmt = try_parse("GRAPH RAG FUSION ON mycol QUERY ARRAY[1.0, 0.0]").unwrap();
        match stmt {
            NodedbStatement::GraphRagFusion { collection, params } => {
                assert_eq!(collection, "mycol");
                assert_eq!(params.query_vector.as_ref().map(|v| v.len()), Some(2));
                assert_eq!(params.vector_top_k, None);
                assert_eq!(params.expansion_depth, None);
                assert_eq!(params.edge_label, None);
                assert_eq!(params.final_top_k, None);
                assert_eq!(params.rrf_k, None);
                assert_eq!(params.vector_field, None);
                assert_eq!(params.direction, None);
                assert_eq!(params.max_visited, None);
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_direction_and_max_visited() {
        let stmt =
            try_parse("GRAPH RAG FUSION ON col QUERY ARRAY[0.5] DIRECTION both MAX_VISITED 500")
                .unwrap();
        match stmt {
            NodedbStatement::GraphRagFusion { params, .. } => {
                assert_eq!(params.direction, Some(GraphDirection::Both));
                assert_eq!(params.max_visited, Some(500));
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_vector_field_is_captured() {
        let stmt =
            try_parse("GRAPH RAG FUSION ON col QUERY ARRAY[0.5] VECTOR_FIELD 'embedding'").unwrap();
        match stmt {
            NodedbStatement::GraphRagFusion { params, .. } => {
                assert_eq!(params.vector_field.as_deref(), Some("embedding"));
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_rrf_k_both_values_captured() {
        let stmt = try_parse("GRAPH RAG FUSION ON col QUERY ARRAY[0.5] RRF_K (1.0, 99.5)").unwrap();
        match stmt {
            NodedbStatement::GraphRagFusion { params, .. } => {
                let (k1, k2) = params.rrf_k.expect("RRF_K must be parsed");
                assert!((k1 - 1.0).abs() < 1e-10, "vector_k must be 1.0, got {k1}");
                assert!((k2 - 99.5).abs() < 1e-10, "graph_k must be 99.5, got {k2}");
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_missing_collection_returns_none() {
        // ON keyword is absent — must not produce a statement.
        let result = try_parse("GRAPH RAG FUSION QUERY ARRAY[0.1] VECTOR_TOP_K 5");
        assert!(result.is_none(), "missing ON <collection> must return None");
    }
}