Skip to main content

krishiv_sql/
unnest_sql.rs

1//! E5.2 — LATERAL / UNNEST SQL pre-processing.
2//!
3//! DataFusion 53 supports `UNNEST` inside `SELECT` and `FROM` clauses for
4//! fixed-size arrays. This module provides:
5//!
6//! 1. **Detection**: identify UNNEST calls in SQL text before passing to DataFusion.
7//! 2. **Rewriter**: normalise the common `LATERAL UNNEST` idiom to a canonical
8//!    form that DataFusion can plan (`CROSS JOIN UNNEST`).
9//! 3. **NodeOp builder**: return a `NodeOp::Unnest` descriptor so the Krishiv
10//!    plan layer can record the unnest operator.
11
12use krishiv_plan::NodeOp;
13
14// ── Detection ─────────────────────────────────────────────────────────────────
15
/// Returns `true` if `sql` contains an `UNNEST` call (case-insensitive).
///
/// `UNNEST` must be a standalone keyword followed by `(`, optionally with
/// whitespace or comments in between. Identifiers that merely contain the word
/// (e.g. `my_unnest(`) and text inside string literals, quoted identifiers or
/// comments are ignored.
pub fn contains_unnest(sql: &str) -> bool {
    sql_tokens(sql)
        .windows(2)
        .any(|pair| is_keyword(sql, pair[0], "UNNEST") && token_text(sql, pair[1]) == "(")
}

/// Returns `true` if `sql` contains a `LATERAL` keyword (case-insensitive).
///
/// The keyword is found even when it touches punctuation or line breaks
/// (`t,LATERAL`, `t,\nLATERAL`). Occurrences inside string literals, quoted
/// identifiers or comments do not count.
pub fn contains_lateral(sql: &str) -> bool {
    sql_tokens(sql)
        .into_iter()
        .any(|tok| is_keyword(sql, tok, "LATERAL"))
}

// ── Token scanning ────────────────────────────────────────────────────────────

/// Byte span `(start, end)` of one significant token within the SQL text.
///
/// Both ends always fall on UTF-8 character boundaries (see [`sql_tokens`]),
/// so a span can be used to slice the original `&str`.
type TokenSpan = (usize, usize);

/// Split `sql` into significant tokens, skipping whitespace and comments.
///
/// Emits one span per:
/// * word — a run of ASCII alphanumerics, `_`, or non-ASCII bytes;
/// * quoted string literal / quoted identifier (`'…'`, `"…"`), where a doubled
///   quote is treated as an escaped quote;
/// * any other single (ASCII) punctuation byte.
///
/// `-- …` line comments and `/* … */` block comments produce no tokens. This
/// is deliberately a minimal lexer: it only needs to locate keywords without
/// matching inside literals or comments. Unterminated literals or comments run
/// to the end of the text.
fn sql_tokens(sql: &str) -> Vec<TokenSpan> {
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < len {
        let start = i;
        match bytes[i] {
            c if c.is_ascii_whitespace() => {
                i += 1;
                continue;
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: everything up to (not including) the newline.
                while i < len && bytes[i] != b'\n' {
                    i += 1;
                }
                continue;
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Block comment: everything through the closing `*/`.
                i += 2;
                while i < len && !bytes[i..].starts_with(b"*/") {
                    i += 1;
                }
                i = (i + 2).min(len);
                continue;
            }
            quote @ (b'\'' | b'"') => {
                i += 1;
                while i < len {
                    if bytes[i] == quote {
                        // A doubled quote is an escaped quote, not the end.
                        if bytes.get(i + 1) == Some(&quote) {
                            i += 2;
                            continue;
                        }
                        i += 1;
                        break;
                    }
                    i += 1;
                }
            }
            c if is_word_byte(c) => {
                while i < len && is_word_byte(bytes[i]) {
                    i += 1;
                }
            }
            // Every non-ASCII byte is a word byte, so this is one ASCII byte.
            _ => i += 1,
        }
        tokens.push((start, i));
    }
    tokens
}

/// Whether `b` can be part of an unquoted SQL word.
///
/// Every non-ASCII byte counts as a word byte, which guarantees that word
/// boundaries never split a multi-byte UTF-8 character.
fn is_word_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
}

/// The source text covered by `span`.
fn token_text(sql: &str, (start, end): TokenSpan) -> &str {
    &sql[start..end]
}

/// Whether the token at `span` is exactly `keyword`, ignoring ASCII case.
fn is_keyword(sql: &str, span: TokenSpan, keyword: &str) -> bool {
    token_text(sql, span).eq_ignore_ascii_case(keyword)
}

// ── Rewriter ──────────────────────────────────────────────────────────────────

/// Rewrite `LATERAL UNNEST(...)` idioms to a form DataFusion understands.
///
/// Normalises:
/// ```sql
/// SELECT * FROM t, LATERAL UNNEST(t.tags) AS tag(value)
/// ```
/// to:
/// ```sql
/// SELECT * FROM t CROSS JOIN UNNEST(t.tags) AS tag(value)
/// ```
///
/// The rewrite is token based, so keyword case, whitespace or line breaks
/// between tokens, and a missing space after the comma do not matter. Three
/// positions are handled:
/// * `t, LATERAL UNNEST(` — the comma and `LATERAL` become ` CROSS JOIN`;
/// * `… JOIN LATERAL UNNEST(` — only `LATERAL` is dropped. Inserting another
///   `CROSS JOIN` here would produce the invalid `JOIN CROSS JOIN`;
/// * anywhere else — `LATERAL` is replaced by `CROSS JOIN` in place.
///
/// Text inside string literals, quoted identifiers and comments is never
/// modified. Queries that do not contain `LATERAL UNNEST` are returned
/// unchanged.
///
/// # Limitations
/// Only `LATERAL` immediately followed by `UNNEST(` is rewritten. Other
/// `LATERAL` uses (e.g. LATERAL subqueries) are passed through untouched for
/// DataFusion to plan or reject.
pub fn rewrite_lateral_unnest(sql: &str) -> String {
    let tokens = sql_tokens(sql);

    // Planned edits as `(start, end, replacement)` byte ranges. They are
    // produced in ascending, non-overlapping order because each one ends
    // before the `UNNEST` token that follows its `LATERAL`.
    let mut edits: Vec<(usize, usize, &str)> = Vec::new();
    for (idx, &lateral) in tokens.iter().enumerate() {
        let starts_lateral_unnest = is_keyword(sql, lateral, "LATERAL")
            && matches!(tokens.get(idx + 1), Some(&next) if is_keyword(sql, next, "UNNEST"))
            && matches!(tokens.get(idx + 2), Some(&next) if token_text(sql, next) == "(");
        if !starts_lateral_unnest {
            continue;
        }
        let unnest_start = tokens[idx + 1].0;
        let previous = idx.checked_sub(1).map(|p| tokens[p]);
        let edit = match previous {
            Some(prev) if token_text(sql, prev) == "," => (prev.0, lateral.1, " CROSS JOIN"),
            Some(prev) if is_keyword(sql, prev, "JOIN") => (lateral.0, unnest_start, ""),
            _ => (lateral.0, lateral.1, "CROSS JOIN"),
        };
        edits.push(edit);
    }

    if edits.is_empty() {
        return sql.to_owned();
    }

    // Each edit grows the text by at most 3 bytes (",LATERAL" -> " CROSS JOIN").
    let mut rewritten = String::with_capacity(sql.len() + 3 * edits.len());
    let mut cursor = 0;
    for (start, end, replacement) in edits {
        rewritten.push_str(&sql[cursor..start]);
        rewritten.push_str(replacement);
        cursor = end;
    }
    rewritten.push_str(&sql[cursor..]);
    rewritten
}
77
78// ── NodeOp builder ────────────────────────────────────────────────────────────
79
80/// Build a `NodeOp::Unnest` descriptor.
81///
82/// * `array_column` — the source column that contains the array.
83/// * `output_column` — the name of the column produced for each element.
84/// * `with_ordinality` — when `true` an extra `ordinality` column (`u64`) is
85///   appended with the 1-based position of each element.
86pub fn build_unnest_op(
87    array_column: impl Into<String>,
88    output_column: impl Into<String>,
89    with_ordinality: bool,
90) -> NodeOp {
91    NodeOp::Unnest {
92        array_column: array_column.into(),
93        output_column: output_column.into(),
94        with_ordinality,
95    }
96}
97
98// ── Tests ─────────────────────────────────────────────────────────────────────
99
#[cfg(test)]
mod tests {
    use super::*;

    // Comma-style LATERAL query shared by both detection tests.
    const LATERAL_IDS_SQL: &str = "SELECT * FROM t, LATERAL UNNEST(t.ids) AS id(v)";

    #[test]
    fn detects_unnest_call() {
        for sql in ["SELECT UNNEST(tags) FROM t", LATERAL_IDS_SQL] {
            assert!(contains_unnest(sql));
        }
        assert!(!contains_unnest("SELECT * FROM t WHERE x = 1"));
    }

    #[test]
    fn detects_lateral() {
        assert!(contains_lateral(LATERAL_IDS_SQL));
        assert!(!contains_lateral("SELECT * FROM t CROSS JOIN UNNEST(t.ids)"));
    }

    #[test]
    fn rewrites_lateral_unnest_with_comma() {
        let upper =
            rewrite_lateral_unnest("SELECT * FROM t, LATERAL UNNEST(t.tags) AS tag(value)")
                .to_ascii_uppercase();
        assert!(!upper.contains(" LATERAL "));
        assert!(upper.contains("CROSS JOIN UNNEST"));
    }

    #[test]
    fn rewrites_lateral_unnest_preserves_alias() {
        let rewritten = rewrite_lateral_unnest(
            "SELECT t.id, tag.value FROM t, LATERAL UNNEST(t.tags) AS tag(value)",
        );
        assert!(
            rewritten.contains("tag(value)"),
            "alias preserved: {rewritten}"
        );
    }

    #[test]
    fn passthrough_plain_unnest() {
        let sql = "SELECT id, UNNEST(tags) AS tag FROM t";
        let rewritten = rewrite_lateral_unnest(sql);
        assert_eq!(rewritten, sql, "unchanged: {rewritten}");
    }

    #[test]
    fn passthrough_non_lateral_unnest_unchanged() {
        let sql = "SELECT * FROM t CROSS JOIN UNNEST(t.ids) AS elem(v)";
        assert_eq!(rewrite_lateral_unnest(sql), sql);
    }

    #[test]
    fn build_unnest_op_returns_correct_variant() {
        let NodeOp::Unnest {
            array_column,
            output_column,
            with_ordinality,
        } = build_unnest_op("tags", "tag", false)
        else {
            panic!("expected Unnest variant");
        };
        assert_eq!(array_column, "tags");
        assert_eq!(output_column, "tag");
        assert!(!with_ordinality);
    }

    #[test]
    fn build_unnest_op_with_ordinality() {
        let NodeOp::Unnest {
            with_ordinality, ..
        } = build_unnest_op("items", "item", true)
        else {
            panic!("expected Unnest");
        };
        assert!(with_ordinality);
    }
}