// nodedb_sql/ddl_ast/graph_parse.rs
//! Typed parsing for the graph DSL (`GRAPH ...`, `MATCH ...`).
//!
//! The handler layer historically parsed graph statements with
//! `upper.find("KEYWORD")` substring matching, which collapsed when
//! a node id, label, or property value shadowed a DSL keyword. This
//! module is the structural replacement: a quote- and brace-aware
//! tokeniser feeds a variant-building parser that produces a typed
//! [`NodedbStatement`]. Every graph DSL command flows through here
//! before reaching a pgwire handler, so the handlers never touch
//! raw SQL again.
//!
//! Values collected here are intentionally unvalidated — numeric
//! bounds, absent-but-required fields, and engine-level caps are
//! enforced at the pgwire boundary where the error response is
//! formed. That keeps this module free of `pgwire` dependencies
//! and out of the `nodedb` → `nodedb-sql` dependency edge.

use std::borrow::Cow;

use super::statement::{GraphDirection, GraphProperties, NodedbStatement};
22/// Entry point: returns `Some` when `sql` starts with a graph DSL
23/// keyword (`GRAPH ...`, `MATCH ...`, `OPTIONAL MATCH ...`), else
24/// `None` so the main DDL parser can continue trying other cases.
25///
26/// Within those prefixes a malformed statement still returns
27/// `Some(stmt)` whenever a sensible variant can be produced; the
28/// pgwire handler decides whether required fields are present.
29/// Statements that cannot be shaped into any variant return `None`
30/// so the legacy string-prefix router can still handle them during
31/// the migration window.
32pub fn try_parse(sql: &str) -> Option<NodedbStatement> {
33    let trimmed = sql.trim();
34    let upper = trimmed.to_ascii_uppercase();
35
36    if upper.starts_with("MATCH ") || upper.starts_with("OPTIONAL MATCH ") {
37        return Some(NodedbStatement::MatchQuery {
38            raw_sql: trimmed.to_string(),
39        });
40    }
41
42    if !upper.starts_with("GRAPH ") {
43        return None;
44    }
45
46    let toks = tokenize(trimmed);
47
48    if upper.starts_with("GRAPH INSERT EDGE ") {
49        return parse_insert_edge(&toks);
50    }
51    if upper.starts_with("GRAPH DELETE EDGE ") {
52        return parse_delete_edge(&toks);
53    }
54    if upper.starts_with("GRAPH LABEL ") {
55        return parse_set_labels(&toks, false);
56    }
57    if upper.starts_with("GRAPH UNLABEL ") {
58        return parse_set_labels(&toks, true);
59    }
60    if upper.starts_with("GRAPH TRAVERSE ") {
61        return parse_traverse(&toks);
62    }
63    if upper.starts_with("GRAPH NEIGHBORS ") {
64        return parse_neighbors(&toks);
65    }
66    if upper.starts_with("GRAPH PATH ") {
67        return parse_path(&toks);
68    }
69    if upper.starts_with("GRAPH ALGO ") {
70        return parse_algo(&toks);
71    }
72
73    None
74}
75
// ── Variant builders ─────────────────────────────────────────────

78fn parse_insert_edge(toks: &[Tok<'_>]) -> Option<NodedbStatement> {
79    let src = quoted_after(toks, "FROM")?;
80    let dst = quoted_after(toks, "TO")?;
81    let label = quoted_after(toks, "TYPE")?;
82    let properties = extract_properties(toks);
83    Some(NodedbStatement::GraphInsertEdge {
84        src,
85        dst,
86        label,
87        properties,
88    })
89}
90
91fn parse_delete_edge(toks: &[Tok<'_>]) -> Option<NodedbStatement> {
92    let src = quoted_after(toks, "FROM")?;
93    let dst = quoted_after(toks, "TO")?;
94    let label = quoted_after(toks, "TYPE")?;
95    Some(NodedbStatement::GraphDeleteEdge { src, dst, label })
96}
97
98fn parse_set_labels(toks: &[Tok<'_>], remove: bool) -> Option<NodedbStatement> {
99    let keyword = if remove { "UNLABEL" } else { "LABEL" };
100    let node_id = quoted_after(toks, keyword)?;
101    let labels = quoted_list_after(toks, "AS");
102    Some(NodedbStatement::GraphSetLabels {
103        node_id,
104        labels,
105        remove,
106    })
107}
108
109fn parse_traverse(toks: &[Tok<'_>]) -> Option<NodedbStatement> {
110    let start = quoted_after(toks, "FROM")?;
111    let depth = usize_after(toks, "DEPTH").unwrap_or(2);
112    let edge_label = quoted_after(toks, "LABEL");
113    let direction = direction_after(toks);
114    Some(NodedbStatement::GraphTraverse {
115        start,
116        depth,
117        edge_label,
118        direction,
119    })
120}
121
122fn parse_neighbors(toks: &[Tok<'_>]) -> Option<NodedbStatement> {
123    let node = quoted_after(toks, "OF")?;
124    let edge_label = quoted_after(toks, "LABEL");
125    let direction = direction_after(toks);
126    Some(NodedbStatement::GraphNeighbors {
127        node,
128        edge_label,
129        direction,
130    })
131}
132
133fn parse_path(toks: &[Tok<'_>]) -> Option<NodedbStatement> {
134    let src = quoted_after(toks, "FROM")?;
135    let dst = quoted_after(toks, "TO")?;
136    let max_depth = usize_after(toks, "MAX_DEPTH").unwrap_or(10);
137    let edge_label = quoted_after(toks, "LABEL");
138    Some(NodedbStatement::GraphPath {
139        src,
140        dst,
141        max_depth,
142        edge_label,
143    })
144}
145
146fn parse_algo(toks: &[Tok<'_>]) -> Option<NodedbStatement> {
147    // Algorithm name is the first `Word` token after `ALGO`.
148    let algorithm = find_keyword(toks, "ALGO").and_then(|i| match toks.get(i + 1)? {
149        Tok::Word(w) => Some(w.to_ascii_uppercase()),
150        _ => None,
151    })?;
152    let collection = word_after(toks, "ON")?.to_lowercase();
153    Some(NodedbStatement::GraphAlgo {
154        algorithm,
155        collection,
156        damping: float_after(toks, "DAMPING"),
157        tolerance: float_after(toks, "TOLERANCE"),
158        resolution: float_after(toks, "RESOLUTION"),
159        max_iterations: usize_after(toks, "ITERATIONS"),
160        sample_size: usize_after(toks, "SAMPLE"),
161        source_node: quoted_after(toks, "FROM"),
162        direction: word_after(toks, "DIRECTION"),
163        mode: word_after(toks, "MODE"),
164    })
165}
166
167fn extract_properties(toks: &[Tok<'_>]) -> GraphProperties {
168    let Some(pos) = find_keyword(toks, "PROPERTIES") else {
169        return GraphProperties::None;
170    };
171    match toks.get(pos + 1) {
172        Some(Tok::Object(obj_str)) => GraphProperties::Object((*obj_str).to_string()),
173        Some(Tok::Quoted(s)) => GraphProperties::Quoted(s.clone().into_owned()),
174        _ => GraphProperties::None,
175    }
176}
177
178fn direction_after(toks: &[Tok<'_>]) -> GraphDirection {
179    match word_after(toks, "DIRECTION")
180        .as_deref()
181        .map(str::to_ascii_uppercase)
182        .as_deref()
183    {
184        Some("IN") => GraphDirection::In,
185        Some("BOTH") => GraphDirection::Both,
186        _ => GraphDirection::Out,
187    }
188}
189
// ── Tokeniser ────────────────────────────────────────────────────

/// One lexical unit produced by [`tokenize`].
enum Tok<'a> {
    /// Bare word: DSL keywords, unquoted identifiers, numbers.
    Word(&'a str),
    /// Single-quoted string; borrows the source when no `''` escape
    /// is present, otherwise owns the unescaped copy.
    Quoted(Cow<'a, str>),
    /// Brace-balanced object literal, including the outer braces.
    Object(&'a str),
}

199fn tokenize(sql: &str) -> Vec<Tok<'_>> {
200    let bytes = sql.as_bytes();
201    let mut out = Vec::new();
202    let mut i = 0;
203    while i < bytes.len() {
204        let b = bytes[i];
205        if b.is_ascii_whitespace() || b == b',' || b == b';' || b == b'(' || b == b')' {
206            i += 1;
207            continue;
208        }
209        if b == b'\'' {
210            i = consume_quoted(sql, bytes, i, &mut out);
211            continue;
212        }
213        if b == b'{' {
214            i = consume_object(sql, bytes, i, &mut out);
215            continue;
216        }
217        i = consume_word(sql, bytes, i, &mut out);
218    }
219    out
220}
221
222fn consume_quoted<'a>(sql: &'a str, bytes: &[u8], start: usize, out: &mut Vec<Tok<'a>>) -> usize {
223    let content_start = start + 1;
224    let mut j = content_start;
225    let mut has_escape = false;
226    while j < bytes.len() {
227        if bytes[j] == b'\'' {
228            if j + 1 < bytes.len() && bytes[j + 1] == b'\'' {
229                has_escape = true;
230                j += 2;
231                continue;
232            }
233            break;
234        }
235        j += 1;
236    }
237    let slice = &sql[content_start..j];
238    let content = if has_escape {
239        Cow::Owned(slice.replace("''", "'"))
240    } else {
241        Cow::Borrowed(slice)
242    };
243    out.push(Tok::Quoted(content));
244    if j < bytes.len() { j + 1 } else { j }
245}
246
247fn consume_object<'a>(sql: &'a str, bytes: &[u8], start: usize, out: &mut Vec<Tok<'a>>) -> usize {
248    let mut depth = 0i32;
249    let mut j = start;
250    let mut in_quote = false;
251    while j < bytes.len() {
252        let c = bytes[j];
253        if in_quote {
254            if c == b'\'' {
255                if j + 1 < bytes.len() && bytes[j + 1] == b'\'' {
256                    j += 2;
257                    continue;
258                }
259                in_quote = false;
260            }
261        } else {
262            match c {
263                b'\'' => in_quote = true,
264                b'{' => depth += 1,
265                b'}' => {
266                    depth -= 1;
267                    if depth == 0 {
268                        j += 1;
269                        break;
270                    }
271                }
272                _ => {}
273            }
274        }
275        j += 1;
276    }
277    out.push(Tok::Object(&sql[start..j]));
278    j
279}
280
281fn consume_word<'a>(sql: &'a str, bytes: &[u8], start: usize, out: &mut Vec<Tok<'a>>) -> usize {
282    let mut j = start;
283    while j < bytes.len() {
284        let c = bytes[j];
285        if c.is_ascii_whitespace()
286            || c == b'\''
287            || c == b'{'
288            || c == b','
289            || c == b';'
290            || c == b'('
291            || c == b')'
292        {
293            break;
294        }
295        j += 1;
296    }
297    if j > start {
298        out.push(Tok::Word(&sql[start..j]));
299        j
300    } else {
301        start + 1
302    }
303}
304
// ── Token-level extraction helpers ───────────────────────────────

307fn find_keyword(toks: &[Tok<'_>], keyword: &str) -> Option<usize> {
308    toks.iter()
309        .position(|t| matches!(t, Tok::Word(w) if w.eq_ignore_ascii_case(keyword)))
310}
311
312fn quoted_after(toks: &[Tok<'_>], keyword: &str) -> Option<String> {
313    let pos = find_keyword(toks, keyword)?;
314    match toks.get(pos + 1)? {
315        Tok::Quoted(s) => Some(s.clone().into_owned()),
316        Tok::Word(w) => Some((*w).to_string()),
317        Tok::Object(_) => None,
318    }
319}
320
321fn quoted_list_after(toks: &[Tok<'_>], keyword: &str) -> Vec<String> {
322    let Some(pos) = find_keyword(toks, keyword) else {
323        return Vec::new();
324    };
325    toks[pos + 1..]
326        .iter()
327        .map_while(|t| match t {
328            Tok::Quoted(s) => Some(s.clone().into_owned()),
329            _ => None,
330        })
331        .collect()
332}
333
334fn word_after(toks: &[Tok<'_>], keyword: &str) -> Option<String> {
335    let pos = find_keyword(toks, keyword)?;
336    if let Tok::Word(w) = toks.get(pos + 1)? {
337        Some((*w).to_string())
338    } else {
339        None
340    }
341}
342
343fn usize_after(toks: &[Tok<'_>], keyword: &str) -> Option<usize> {
344    word_after(toks, keyword)?.parse().ok()
345}
346
347fn float_after(toks: &[Tok<'_>], keyword: &str) -> Option<f64> {
348    word_after(toks, keyword)?.parse().ok()
349}
350
// Unit tests. The first three target the failure mode of the old
// substring router: ids, labels, and property payloads that contain
// DSL keywords must not be mistaken for clause boundaries.
#[cfg(test)]
mod tests {
    use super::*;

    // Quoted ids named 'TO' / 'FROM' / 'LABEL' must bind to the
    // clause positions, not to the keyword scanner.
    #[test]
    fn parse_graph_insert_edge_keyword_shaped_ids() {
        let stmt = try_parse("GRAPH INSERT EDGE FROM 'TO' TO 'FROM' TYPE 'LABEL'").unwrap();
        match stmt {
            NodedbStatement::GraphInsertEdge {
                src,
                dst,
                label,
                properties,
            } => {
                assert_eq!(src, "TO");
                assert_eq!(dst, "FROM");
                assert_eq!(label, "LABEL");
                assert_eq!(properties, GraphProperties::None);
            }
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    // A `}` and a keyword inside a quoted string within the object
    // literal must not terminate the literal or start a clause.
    #[test]
    fn parse_graph_insert_edge_with_object_properties() {
        let stmt = try_parse(
            "GRAPH INSERT EDGE FROM 'a' TO 'b' TYPE 'l' PROPERTIES { note: '} DEPTH 999' }",
        )
        .unwrap();
        match stmt {
            NodedbStatement::GraphInsertEdge { properties, .. } => match properties {
                GraphProperties::Object(s) => assert!(s.contains("} DEPTH 999")),
                other => panic!("expected Object properties, got {other:?}"),
            },
            other => panic!("expected GraphInsertEdge, got {other:?}"),
        }
    }

    // "DEPTH" inside a quoted node id must not hijack the DEPTH
    // clause; the real clause value (2) must still win.
    #[test]
    fn parse_graph_traverse_keyword_substring_id() {
        let stmt =
            try_parse("GRAPH TRAVERSE FROM 'node_with_DEPTH_in_name' DEPTH 2 LABEL 'l'").unwrap();
        match stmt {
            NodedbStatement::GraphTraverse { start, depth, .. } => {
                assert_eq!(start, "node_with_DEPTH_in_name");
                assert_eq!(depth, 2);
            }
            other => panic!("expected GraphTraverse, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_path() {
        let stmt = try_parse("GRAPH PATH FROM 'a' TO 'b' MAX_DEPTH 5 LABEL 'l'").unwrap();
        match stmt {
            NodedbStatement::GraphPath {
                src,
                dst,
                max_depth,
                edge_label,
            } => {
                assert_eq!(src, "a");
                assert_eq!(dst, "b");
                assert_eq!(max_depth, 5);
                assert_eq!(edge_label.as_deref(), Some("l"));
            }
            other => panic!("expected GraphPath, got {other:?}"),
        }
    }

    // Comma-separated quoted labels after AS collect into a list.
    #[test]
    fn parse_graph_labels_list() {
        let stmt = try_parse("GRAPH LABEL 'alice' AS 'Person', 'User'").unwrap();
        match stmt {
            NodedbStatement::GraphSetLabels {
                node_id,
                labels,
                remove,
            } => {
                assert_eq!(node_id, "alice");
                assert_eq!(labels, vec!["Person".to_string(), "User".to_string()]);
                assert!(!remove);
            }
            other => panic!("expected GraphSetLabels, got {other:?}"),
        }
    }

    #[test]
    fn parse_graph_algo_pagerank() {
        let stmt = try_parse("GRAPH ALGO PAGERANK ON users ITERATIONS 5 DAMPING 0.85").unwrap();
        match stmt {
            NodedbStatement::GraphAlgo {
                algorithm,
                collection,
                damping,
                max_iterations,
                ..
            } => {
                assert_eq!(algorithm, "PAGERANK");
                assert_eq!(collection, "users");
                assert_eq!(damping, Some(0.85));
                assert_eq!(max_iterations, Some(5));
            }
            other => panic!("expected GraphAlgo, got {other:?}"),
        }
    }

    // MATCH statements pass through untokenised.
    #[test]
    fn parse_match_query_captures_raw() {
        let stmt = try_parse("MATCH (x)-[:l]->(y) RETURN x, y").unwrap();
        match stmt {
            NodedbStatement::MatchQuery { raw_sql } => {
                assert!(raw_sql.starts_with("MATCH"));
            }
            other => panic!("expected MatchQuery, got {other:?}"),
        }
    }

    // Non-graph statements fall through to the main DDL parser.
    #[test]
    fn non_graph_returns_none() {
        assert!(try_parse("SELECT * FROM users").is_none());
        assert!(try_parse("CREATE COLLECTION users").is_none());
    }
}