Skip to main content

nodedb_sql/ddl_ast/graph_parse/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Typed parsing for the graph DSL (`GRAPH ...`, `MATCH ...`).
4//!
5//! The handler layer historically parsed graph statements with
6//! `upper.find("KEYWORD")` substring matching, which collapsed when
7//! a node id, label, or property value shadowed a DSL keyword. This
8//! module is the structural replacement: a quote- and brace-aware
9//! tokeniser feeds a variant-building parser that produces a typed
10//! [`NodedbStatement`]. Every graph DSL command flows through here
11//! before reaching a pgwire handler, so the handlers never touch
12//! raw SQL again.
13//!
14//! Values collected here are intentionally unvalidated — numeric
15//! bounds, absent-but-required fields, and engine-level caps are
16//! enforced at the pgwire boundary where the error response is
17//! formed. That keeps this module free of `pgwire` dependencies
18//! and out of the `nodedb` → `nodedb-sql` dependency edge.
19
20pub mod fusion_params;
21mod helpers;
22mod tokenizer;
23mod variants;
24
25pub use fusion_params::{
26    FusionKeywords, FusionParams, RAG_FUSION_KEYWORDS, SEARCH_FUSION_KEYWORDS,
27    parse_search_using_fusion,
28};
29
30use super::statement::NodedbStatement;
31
32/// Entry point: returns `Some` when `sql` starts with a graph DSL
33/// keyword (`GRAPH ...`, `MATCH ...`, `OPTIONAL MATCH ...`), else
34/// `None` so the main DDL parser can continue trying other cases.
35pub fn try_parse(sql: &str) -> Option<NodedbStatement> {
36    let trimmed = sql.trim();
37    let upper = trimmed.to_ascii_uppercase();
38
39    if upper.starts_with("MATCH ") || upper.starts_with("OPTIONAL MATCH ") {
40        return Some(NodedbStatement::MatchQuery {
41            body: trimmed.to_string(),
42        });
43    }
44
45    if !upper.starts_with("GRAPH ") {
46        return None;
47    }
48
49    let toks = tokenizer::tokenize(trimmed);
50
51    if upper.starts_with("GRAPH INSERT EDGE ") {
52        return variants::parse_insert_edge(&toks);
53    }
54    if upper.starts_with("GRAPH DELETE EDGE ") {
55        return variants::parse_delete_edge(&toks);
56    }
57    if upper.starts_with("GRAPH LABEL ") {
58        return variants::parse_set_labels(&toks, false);
59    }
60    if upper.starts_with("GRAPH UNLABEL ") {
61        return variants::parse_set_labels(&toks, true);
62    }
63    if upper.starts_with("GRAPH TRAVERSE ") {
64        return variants::parse_traverse(&toks);
65    }
66    if upper.starts_with("GRAPH NEIGHBORS ") {
67        return variants::parse_neighbors(&toks);
68    }
69    if upper.starts_with("GRAPH PATH ") {
70        return variants::parse_path(&toks);
71    }
72    if upper.starts_with("GRAPH ALGO ") {
73        return variants::parse_algo(&toks);
74    }
75    if upper.starts_with("GRAPH RAG FUSION ") {
76        return variants::parse_rag_fusion(&toks, trimmed);
77    }
78
79    None
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ddl_ast::statement::{GraphDirection, GraphProperties};

    /// Quoted values that spell DSL keywords (`TO`, `FROM`, `LABEL`) must
    /// end up in the matching fields rather than being read as keywords.
    #[test]
    fn parse_graph_insert_edge_keyword_shaped_ids() {
        let sql = "GRAPH INSERT EDGE IN 'myedges' FROM 'TO' TO 'FROM' TYPE 'LABEL'";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphInsertEdge {
                collection,
                src,
                dst,
                label,
                properties,
            } => {
                assert_eq!(collection, "myedges");
                assert_eq!(src, "TO");
                assert_eq!(dst, "FROM");
                assert_eq!(label, "LABEL");
                assert_eq!(properties, GraphProperties::None);
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    /// `GRAPH DELETE EDGE` with an explicit `IN <collection>` clause.
    #[test]
    fn parse_graph_delete_edge_with_collection() {
        let sql = "GRAPH DELETE EDGE IN 'myedges' FROM 'a' TO 'b' TYPE 'l'";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphDeleteEdge {
                collection,
                src,
                dst,
                label,
            } => {
                assert_eq!(collection, "myedges");
                assert_eq!(src, "a");
                assert_eq!(dst, "b");
                assert_eq!(label, "l");
            }
            other => panic!("expected GraphDeleteEdge, got {other:?}"),
        }
    }

    /// An insert without `IN <collection>` is rejected.
    #[test]
    fn parse_graph_insert_edge_missing_collection_returns_none() {
        let outcome = try_parse("GRAPH INSERT EDGE FROM 'a' TO 'b' TYPE 'l'");
        assert!(
            outcome.is_none(),
            "missing IN <collection> must not produce a statement"
        );
    }

    /// A closing brace and a keyword inside a quoted property value must
    /// stay inside the captured properties object.
    #[test]
    fn parse_graph_insert_edge_with_object_properties() {
        let sql = "GRAPH INSERT EDGE IN 'edges' FROM 'a' TO 'b' TYPE 'l' PROPERTIES { note: '} DEPTH 999' }";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphInsertEdge {
                collection,
                properties,
                ..
            } => {
                assert_eq!(collection, "edges");
                match properties {
                    GraphProperties::Object(raw) => assert!(raw.contains("} DEPTH 999")),
                    other => panic!("expected Object properties, got {other:?}"),
                }
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    /// A start id that merely contains `DEPTH` must not be confused with
    /// the `DEPTH` clause that follows it.
    #[test]
    fn parse_graph_traverse_keyword_substring_id() {
        let sql = "GRAPH TRAVERSE FROM 'node_with_DEPTH_in_name' DEPTH 2 LABEL 'l'";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphTraverse { start, depth, .. } => {
                assert_eq!(start, "node_with_DEPTH_in_name");
                assert_eq!(depth, 2);
            }
            other => panic!("expected GraphTraverse, got {other:?}"),
        }
    }

    /// `GRAPH PATH` with endpoints, depth cap and edge label.
    #[test]
    fn parse_graph_path() {
        let sql = "GRAPH PATH FROM 'a' TO 'b' MAX_DEPTH 5 LABEL 'l'";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphPath {
                src,
                dst,
                max_depth,
                edge_label,
            } => {
                assert_eq!(src, "a");
                assert_eq!(dst, "b");
                assert_eq!(max_depth, 5);
                assert_eq!(edge_label.as_deref(), Some("l"));
            }
            other => panic!("expected GraphPath, got {other:?}"),
        }
    }

    /// `GRAPH LABEL` collects a comma-separated label list and is not a removal.
    #[test]
    fn parse_graph_labels_list() {
        let sql = "GRAPH LABEL 'alice' AS 'Person', 'User'";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphSetLabels {
                node_id,
                labels,
                remove,
            } => {
                assert_eq!(node_id, "alice");
                assert_eq!(labels, vec!["Person".to_string(), "User".to_string()]);
                assert!(!remove);
            }
            other => panic!("expected GraphSetLabels, got {other:?}"),
        }
    }

    /// `GRAPH ALGO PAGERANK` with iteration and damping options.
    #[test]
    fn parse_graph_algo_pagerank() {
        let sql = "GRAPH ALGO PAGERANK ON users ITERATIONS 5 DAMPING 0.85";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphAlgo {
                algorithm,
                collection,
                damping,
                max_iterations,
                ..
            } => {
                assert_eq!(algorithm, "PAGERANK");
                assert_eq!(collection, "users");
                assert_eq!(damping, Some(0.85));
                assert_eq!(max_iterations, Some(5));
            }
            other => panic!("expected GraphAlgo, got {other:?}"),
        }
    }

    /// `MATCH` statements are carried through as raw text.
    #[test]
    fn parse_match_query_captures_raw() {
        match try_parse("MATCH (x)-[:l]->(y) RETURN x, y").unwrap() {
            NodedbStatement::MatchQuery { body } => {
                assert!(body.starts_with("MATCH"));
            }
            other => panic!("expected MatchQuery, got {other:?}"),
        }
    }

    /// Statements outside the graph DSL are left for other parsers.
    #[test]
    fn non_graph_returns_none() {
        assert!(try_parse("SELECT * FROM users").is_none());
        assert!(try_parse("CREATE COLLECTION users").is_none());
    }

    // ── GraphRagFusion parser tests ──────────────────────────────────────

    /// Every option shown in the statement below is captured.
    #[test]
    fn parse_rag_fusion_full_syntax() {
        let sql = "GRAPH RAG FUSION ON entities \
             QUERY ARRAY[0.1, 0.2, 0.3] \
             VECTOR_TOP_K 50 \
             EXPANSION_DEPTH 2 \
             EDGE_LABEL 'related_to' \
             FINAL_TOP_K 10 \
             RRF_K (60.0, 35.0)";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphRagFusion { collection, params } => {
                assert_eq!(collection, "entities");
                let v = params.query_vector.expect("QUERY ARRAY parsed");
                assert_eq!(v.len(), 3);
                assert!((v[0] - 0.1f32).abs() < 1e-5);
                assert_eq!(params.vector_top_k, Some(50));
                assert_eq!(params.expansion_depth, Some(2));
                assert_eq!(params.edge_label.as_deref(), Some("related_to"));
                assert_eq!(params.final_top_k, Some(10));
                let (k1, k2) = params.rrf_k.unwrap();
                assert!((k1 - 60.0).abs() < 1e-10);
                assert!((k2 - 35.0).abs() < 1e-10);
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    /// With only `ON` and `QUERY` given, every optional field stays `None`.
    #[test]
    fn parse_rag_fusion_minimal_defaults_to_none() {
        let sql = "GRAPH RAG FUSION ON mycol QUERY ARRAY[1.0, 0.0]";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphRagFusion { collection, params } => {
                assert_eq!(collection, "mycol");
                assert!(params.query_vector.is_some());
                assert_eq!(params.vector_top_k, None);
                assert_eq!(params.expansion_depth, None);
                assert_eq!(params.edge_label, None);
                assert_eq!(params.final_top_k, None);
                assert_eq!(params.rrf_k, None);
                assert_eq!(params.vector_field, None);
                assert_eq!(params.direction, None);
                assert_eq!(params.max_visited, None);
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    /// `DIRECTION` and `MAX_VISITED` are both captured.
    #[test]
    fn parse_rag_fusion_direction_and_max_visited() {
        let sql = "GRAPH RAG FUSION ON col QUERY ARRAY[0.5] DIRECTION both MAX_VISITED 500";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphRagFusion { params, .. } => {
                assert_eq!(params.direction, Some(GraphDirection::Both));
                assert_eq!(params.max_visited, Some(500));
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    /// `VECTOR_FIELD` keeps the quoted field name.
    #[test]
    fn parse_rag_fusion_vector_field_is_captured() {
        let sql = "GRAPH RAG FUSION ON col QUERY ARRAY[0.5] VECTOR_FIELD 'embedding'";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphRagFusion { params, .. } => {
                assert_eq!(params.vector_field.as_deref(), Some("embedding"));
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    /// Both members of the `RRF_K (a, b)` pair are kept, in order.
    #[test]
    fn parse_rag_fusion_rrf_k_both_values_captured() {
        let sql = "GRAPH RAG FUSION ON col QUERY ARRAY[0.5] RRF_K (1.0, 99.5)";
        match try_parse(sql).unwrap() {
            NodedbStatement::GraphRagFusion { params, .. } => {
                let (k1, k2) = params.rrf_k.expect("RRF_K must be parsed");
                assert!((k1 - 1.0).abs() < 1e-10, "vector_k must be 1.0, got {k1}");
                assert!((k2 - 99.5).abs() < 1e-10, "graph_k must be 99.5, got {k2}");
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    /// Without the `ON` keyword no statement may be produced.
    #[test]
    fn parse_rag_fusion_missing_collection_returns_none() {
        let outcome = try_parse("GRAPH RAG FUSION QUERY ARRAY[0.1] VECTOR_TOP_K 5");
        assert!(outcome.is_none(), "missing ON <collection> must return None");
    }
}