nodedb_sql/ddl_ast/graph_parse/
mod.rs1pub mod fusion_params;
21mod helpers;
22mod tokenizer;
23mod variants;
24
25pub use fusion_params::{
26 FusionKeywords, FusionParams, RAG_FUSION_KEYWORDS, SEARCH_FUSION_KEYWORDS,
27 parse_search_using_fusion,
28};
29
30use super::statement::{GraphStmt, NodedbStatement};
31
32pub fn try_parse(sql: &str) -> Option<NodedbStatement> {
36 let trimmed = sql.trim();
37 let upper = trimmed.to_ascii_uppercase();
38
39 if upper.starts_with("MATCH ") || upper.starts_with("OPTIONAL MATCH ") {
40 return Some(NodedbStatement::Graph(GraphStmt::MatchQuery {
41 body: trimmed.to_string(),
42 }));
43 }
44
45 if !upper.starts_with("GRAPH ") {
46 return None;
47 }
48
49 let toks = tokenizer::tokenize(trimmed);
50
51 if upper.starts_with("GRAPH INSERT EDGE ") {
52 return variants::parse_insert_edge(&toks);
53 }
54 if upper.starts_with("GRAPH DELETE EDGE ") {
55 return variants::parse_delete_edge(&toks);
56 }
57 if upper.starts_with("GRAPH LABEL ") {
58 return variants::parse_set_labels(&toks, false);
59 }
60 if upper.starts_with("GRAPH UNLABEL ") {
61 return variants::parse_set_labels(&toks, true);
62 }
63 if upper.starts_with("GRAPH TRAVERSE ") {
64 return variants::parse_traverse(&toks);
65 }
66 if upper.starts_with("GRAPH NEIGHBORS ") {
67 return variants::parse_neighbors(&toks);
68 }
69 if upper.starts_with("GRAPH PATH ") {
70 return variants::parse_path(&toks);
71 }
72 if upper.starts_with("GRAPH ALGO ") {
73 return variants::parse_algo(&toks);
74 }
75 if upper.starts_with("GRAPH RAG FUSION ") {
76 return variants::parse_rag_fusion(&toks, trimmed);
77 }
78
79 None
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ddl_ast::statement::{GraphDirection, GraphProperties};

    #[test]
    fn parse_graph_insert_edge_keyword_shaped_ids() {
        // Quoted ids that spell out keywords must be kept as plain ids.
        let sql = "GRAPH INSERT EDGE IN 'myedges' FROM 'TO' TO 'FROM' TYPE 'LABEL'";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphInsertEdge {
                collection: coll,
                src: from,
                dst: to,
                label: edge_type,
                properties: props,
            }) => {
                assert_eq!(coll, "myedges");
                assert_eq!(from, "TO");
                assert_eq!(to, "FROM");
                assert_eq!(edge_type, "LABEL");
                assert_eq!(props, GraphProperties::None);
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_delete_edge_with_collection() {
        let sql = "GRAPH DELETE EDGE IN 'myedges' FROM 'a' TO 'b' TYPE 'l'";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphDeleteEdge {
                collection: coll,
                src: from,
                dst: to,
                label: edge_type,
            }) => {
                assert_eq!(coll, "myedges");
                assert_eq!(from, "a");
                assert_eq!(to, "b");
                assert_eq!(edge_type, "l");
            }
            other => panic!("expected GraphDeleteEdge, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_insert_edge_missing_collection_returns_none() {
        // `IN <collection>` is mandatory for edge insertion.
        assert!(
            try_parse("GRAPH INSERT EDGE FROM 'a' TO 'b' TYPE 'l'").is_none(),
            "missing IN <collection> must not produce a statement"
        );
    }

    #[test]
    fn parse_graph_insert_edge_with_object_properties() {
        // The property object contains text that looks like trailing clauses;
        // it must be captured inside the object, not parsed as syntax.
        let sql =
            "GRAPH INSERT EDGE IN 'edges' FROM 'a' TO 'b' TYPE 'l' PROPERTIES { note: '} DEPTH 999' }";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphInsertEdge {
                collection: coll,
                properties: props,
                ..
            }) => {
                assert_eq!(coll, "edges");
                match props {
                    GraphProperties::Object(text) => assert!(text.contains("} DEPTH 999")),
                    other => panic!("expected Object properties, got {other:?}"),
                }
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_traverse_keyword_substring_id() {
        // A keyword embedded inside a quoted id must not be split out.
        let sql = "GRAPH TRAVERSE FROM 'node_with_DEPTH_in_name' DEPTH 2 LABEL 'l'";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphTraverse {
                start: origin,
                depth: hops,
                ..
            }) => {
                assert_eq!(origin, "node_with_DEPTH_in_name");
                assert_eq!(hops, 2);
            }
            other => panic!("expected GraphTraverse, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_path() {
        let sql = "GRAPH PATH FROM 'a' TO 'b' MAX_DEPTH 5 LABEL 'l'";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphPath {
                src: from,
                dst: to,
                max_depth: limit,
                edge_label: label,
            }) => {
                assert_eq!(from, "a");
                assert_eq!(to, "b");
                assert_eq!(limit, 5);
                assert_eq!(label.as_deref(), Some("l"));
            }
            other => panic!("expected GraphPath, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_labels_list() {
        let sql = "GRAPH LABEL 'alice' AS 'Person', 'User'";
        let expected_labels = vec!["Person".to_string(), "User".to_string()];
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphSetLabels {
                node_id: node,
                labels: found,
                remove: is_removal,
            }) => {
                assert_eq!(node, "alice");
                assert_eq!(found, expected_labels);
                assert!(!is_removal);
            }
            other => panic!("expected GraphSetLabels, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_algo_pagerank() {
        let sql = "GRAPH ALGO PAGERANK ON users ITERATIONS 5 DAMPING 0.85";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphAlgo {
                algorithm: algo,
                collection: coll,
                damping: d,
                max_iterations: iters,
                ..
            }) => {
                assert_eq!(algo, "PAGERANK");
                assert_eq!(coll, "users");
                assert_eq!(d, Some(0.85));
                assert_eq!(iters, Some(5));
            }
            other => panic!("expected GraphAlgo, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_algo_personalization() {
        let sql =
            r#"GRAPH ALGO PAGERANK ON 'users' DAMPING 0.9 PERSONALIZATION {"alice": 1.0, "bob": 0.5}"#;
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphAlgo {
                algorithm: algo,
                collection: coll,
                damping: d,
                personalization: weights,
                ..
            }) => {
                assert_eq!(algo, "PAGERANK");
                assert_eq!(coll, "users");
                assert_eq!(d, Some(0.9));
                let json = weights.expect("personalization object present");
                assert!(json.contains("alice"));
                assert!(json.contains("bob"));
                // The captured object must round-trip as a JSON map of weights.
                let parsed: std::collections::HashMap<String, f64> =
                    sonic_rs::from_str(&json).unwrap();
                assert_eq!(parsed.get("alice"), Some(&1.0));
                assert_eq!(parsed.get("bob"), Some(&0.5));
            }
            other => panic!("expected GraphAlgo, got {other:?}"),
        }
    }

    #[test]
    fn parse_match_query_captures_raw() {
        match try_parse("MATCH (x)-[:l]->(y) RETURN x, y").unwrap() {
            NodedbStatement::Graph(GraphStmt::MatchQuery { body: raw }) => {
                assert!(raw.starts_with("MATCH"));
            }
            other => panic!("expected MatchQuery, got {other:?}"),
        }
    }

    #[test]
    fn non_graph_returns_none() {
        for sql in ["SELECT * FROM users", "CREATE COLLECTION users"] {
            assert!(try_parse(sql).is_none());
        }
    }

    #[test]
    fn parse_rag_fusion_full_syntax() {
        // Every optional clause supplied; each one must land in its field.
        let sql = "GRAPH RAG FUSION ON entities \
                   QUERY ARRAY[0.1, 0.2, 0.3] \
                   VECTOR_TOP_K 50 \
                   EXPANSION_DEPTH 2 \
                   EDGE_LABEL 'related_to' \
                   FINAL_TOP_K 10 \
                   RRF_K (60.0, 35.0)";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphRagFusion {
                collection: coll,
                params: p,
            }) => {
                assert_eq!(coll, "entities");
                let query = p.query_vector.expect("QUERY ARRAY parsed");
                assert_eq!(query.len(), 3);
                assert!((query[0] - 0.1f32).abs() < 1e-5);
                assert_eq!(p.vector_top_k, Some(50));
                assert_eq!(p.expansion_depth, Some(2));
                assert_eq!(p.edge_label.as_deref(), Some("related_to"));
                assert_eq!(p.final_top_k, Some(10));
                let (vector_k, graph_k) = p.rrf_k.unwrap();
                assert!((vector_k - 60.0).abs() < 1e-10);
                assert!((graph_k - 35.0).abs() < 1e-10);
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_minimal_defaults_to_none() {
        // Only the required clauses: every optional parameter stays unset.
        match try_parse("GRAPH RAG FUSION ON mycol QUERY ARRAY[1.0, 0.0]").unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphRagFusion {
                collection: coll,
                params: p,
            }) => {
                assert_eq!(coll, "mycol");
                assert!(p.query_vector.is_some());
                assert_eq!(p.vector_top_k, None);
                assert_eq!(p.expansion_depth, None);
                assert_eq!(p.edge_label, None);
                assert_eq!(p.final_top_k, None);
                assert_eq!(p.rrf_k, None);
                assert_eq!(p.vector_field, None);
                assert_eq!(p.direction, None);
                assert_eq!(p.max_visited, None);
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_direction_and_max_visited() {
        let sql = "GRAPH RAG FUSION ON col QUERY ARRAY[0.5] DIRECTION both MAX_VISITED 500";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphRagFusion { params: p, .. }) => {
                assert_eq!(p.direction, Some(GraphDirection::Both));
                assert_eq!(p.max_visited, Some(500));
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_vector_field_is_captured() {
        let sql = "GRAPH RAG FUSION ON col QUERY ARRAY[0.5] VECTOR_FIELD 'embedding'";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphRagFusion { params: p, .. }) => {
                assert_eq!(p.vector_field.as_deref(), Some("embedding"));
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_rrf_k_both_values_captured() {
        let sql = "GRAPH RAG FUSION ON col QUERY ARRAY[0.5] RRF_K (1.0, 99.5)";
        match try_parse(sql).unwrap() {
            NodedbStatement::Graph(GraphStmt::GraphRagFusion { params: p, .. }) => {
                let (k1, k2) = p.rrf_k.expect("RRF_K must be parsed");
                assert!((k1 - 1.0).abs() < 1e-10, "vector_k must be 1.0, got {k1}");
                assert!((k2 - 99.5).abs() < 1e-10, "graph_k must be 99.5, got {k2}");
            }
            other => panic!("expected GraphRagFusion, got {other:?}"),
        }
    }

    #[test]
    fn parse_rag_fusion_missing_collection_returns_none() {
        // `ON <collection>` is mandatory for RAG fusion.
        assert!(
            try_parse("GRAPH RAG FUSION QUERY ARRAY[0.1] VECTOR_TOP_K 5").is_none(),
            "missing ON <collection> must return None"
        );
    }
}
371}