krishiv_sql/
unnest_sql.rs1use krishiv_plan::NodeOp;
13
/// Reports whether `sql` contains something shaped like an `UNNEST(...)` call.
///
/// The keyword is matched ASCII case-insensitively. It must not be glued to a
/// preceding ASCII letter, digit or underscore (so `my_unnest(` is not a hit),
/// and it must be followed by `(` after any amount of ASCII whitespace
/// (spaces, tabs, newlines).
///
/// This is a lexical check only: it does not parse SQL, so the keyword is
/// also found inside string literals and comments.
pub fn contains_unnest(sql: &str) -> bool {
    const KEYWORD: &[u8] = b"UNNEST";

    /// True for bytes that can continue an unquoted ASCII identifier.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_'
    }

    let bytes = sql.as_bytes();
    bytes
        .windows(KEYWORD.len())
        .enumerate()
        .any(|(start, window)| {
            window.eq_ignore_ascii_case(KEYWORD)
                && (start == 0 || !is_word_byte(bytes[start - 1]))
                // The first non-whitespace byte after the keyword must open
                // the argument list; otherwise this is just a name.
                && bytes[start + KEYWORD.len()..]
                    .iter()
                    .find(|byte| !byte.is_ascii_whitespace())
                    == Some(&b'(')
        })
}
21
/// Reports whether `sql` contains the word `LATERAL`.
///
/// The match is ASCII case-insensitive and word-bounded: the bytes directly
/// before and after the keyword must not be ASCII letters, digits or
/// underscores (start and end of input also count as boundaries). This means
/// `t,LATERAL UNNEST(...)` and keywords delimited by tabs or newlines are
/// detected, while identifiers such as `collateral` or `lateral_id` are not.
///
/// This is a lexical check only: it does not parse SQL, so the word is also
/// found inside string literals and comments.
pub fn contains_lateral(sql: &str) -> bool {
    const KEYWORD: &[u8] = b"LATERAL";

    /// True for bytes that can continue an unquoted ASCII identifier.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_'
    }

    let bytes = sql.as_bytes();
    bytes
        .windows(KEYWORD.len())
        .enumerate()
        .any(|(start, window)| {
            let starts_word = start == 0 || !is_word_byte(bytes[start - 1]);
            let ends_word = bytes
                .get(start + KEYWORD.len())
                .map_or(true, |&next| !is_word_byte(next));
            starts_word && ends_word && window.eq_ignore_ascii_case(KEYWORD)
        })
}
26
/// Rewrites `LATERAL UNNEST(...)` table references into a form without the
/// `LATERAL` keyword, returning the rewritten SQL as a new `String`.
///
/// Keywords are matched ASCII case-insensitively and any run of ASCII
/// whitespace (spaces, tabs, newlines) is accepted between tokens, including
/// between `UNNEST` and its opening parenthesis. Each occurrence is handled
/// according to what precedes it:
///
/// * after a comma (`t, LATERAL UNNEST(` / `t,LATERAL UNNEST(`): the comma,
///   the surrounding whitespace and `LATERAL` become ` CROSS JOIN `;
/// * after `JOIN` or `FROM` (`CROSS JOIN LATERAL UNNEST(`): `LATERAL` is
///   simply dropped, so the existing join keyword is not duplicated;
/// * anything else: `LATERAL` is replaced by `CROSS JOIN`.
///
/// Everything from `UNNEST` onward (arguments, aliases) is copied verbatim,
/// and input without a `LATERAL UNNEST(` sequence is returned unchanged.
///
/// The scan is purely lexical: quoted literals and comments are not
/// recognised, so a matching sequence inside them is rewritten as well.
pub fn rewrite_lateral_unnest(sql: &str) -> String {
    const LATERAL: &[u8] = b"LATERAL";
    const UNNEST: &[u8] = b"UNNEST";
    const JOIN: &[u8] = b"JOIN";
    const FROM: &[u8] = b"FROM";

    /// True for bytes that can continue an unquoted ASCII identifier.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_'
    }

    /// True when `keyword` occurs at `at` and is not the tail of a longer word.
    fn word_starts_at(bytes: &[u8], at: usize, keyword: &[u8]) -> bool {
        let matches_keyword = bytes
            .get(at..at + keyword.len())
            .map_or(false, |window| window.eq_ignore_ascii_case(keyword));
        matches_keyword && (at == 0 || !is_word_byte(bytes[at - 1]))
    }

    /// True when the word that ends right before `end` is `keyword`.
    fn word_ends_at(bytes: &[u8], end: usize, keyword: &[u8]) -> bool {
        end >= keyword.len() && word_starts_at(bytes, end - keyword.len(), keyword)
    }

    /// Index of the first non-whitespace byte at or after `from`.
    fn skip_whitespace(bytes: &[u8], from: usize) -> usize {
        bytes[from..]
            .iter()
            .position(|byte| !byte.is_ascii_whitespace())
            .map_or(bytes.len(), |offset| from + offset)
    }

    /// If `LATERAL <ws> UNNEST <ws>? (` starts at `at`, yields the index of `UNNEST`.
    fn lateral_unnest_at(bytes: &[u8], at: usize) -> Option<usize> {
        if !word_starts_at(bytes, at, LATERAL) {
            return None;
        }
        let after_lateral = at + LATERAL.len();
        let unnest_start = skip_whitespace(bytes, after_lateral);
        // The two keywords must be separated by at least one whitespace byte.
        if unnest_start == after_lateral || !word_starts_at(bytes, unnest_start, UNNEST) {
            return None;
        }
        let paren = skip_whitespace(bytes, unnest_start + UNNEST.len());
        if bytes.get(paren) == Some(&b'(') {
            Some(unnest_start)
        } else {
            None
        }
    }

    let bytes = sql.as_bytes();
    let mut rewritten = String::with_capacity(sql.len());
    // `sql[..copied_to]` has already been emitted (possibly rewritten).
    let mut copied_to = 0;
    let mut cursor = 0;

    while cursor < bytes.len() {
        let unnest_start = match lateral_unnest_at(bytes, cursor) {
            Some(index) => index,
            None => {
                cursor += 1;
                continue;
            }
        };

        // End of the text before LATERAL once trailing whitespace is ignored.
        // All slice bounds below sit on ASCII bytes, hence on char boundaries.
        let mut prefix_end = cursor;
        while prefix_end > copied_to && bytes[prefix_end - 1].is_ascii_whitespace() {
            prefix_end -= 1;
        }

        if prefix_end > copied_to && bytes[prefix_end - 1] == b',' {
            // Implicit join: swallow the comma and its trailing whitespace.
            rewritten.push_str(&sql[copied_to..prefix_end - 1]);
            rewritten.push_str(" CROSS JOIN ");
        } else {
            rewritten.push_str(&sql[copied_to..cursor]);
            let has_join_keyword =
                word_ends_at(bytes, prefix_end, JOIN) || word_ends_at(bytes, prefix_end, FROM);
            // After JOIN/FROM the keyword is noise; inserting another join
            // would produce `JOIN CROSS JOIN`, which is not valid SQL.
            if !has_join_keyword {
                rewritten.push_str("CROSS JOIN ");
            }
        }

        copied_to = unnest_start;
        cursor = unnest_start + UNNEST.len();
    }

    rewritten.push_str(&sql[copied_to..]);
    rewritten
}
77
78pub fn build_unnest_op(
87 array_column: impl Into<String>,
88 output_column: impl Into<String>,
89 with_ordinality: bool,
90) -> NodeOp {
91 NodeOp::Unnest {
92 array_column: array_column.into(),
93 output_column: output_column.into(),
94 with_ordinality,
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// `contains_unnest` fires on bare and LATERAL-prefixed calls only.
    #[test]
    fn detects_unnest_call() {
        let with_unnest = [
            "SELECT UNNEST(tags) FROM t",
            "SELECT * FROM t, LATERAL UNNEST(t.ids) AS id(v)",
        ];
        for sql in with_unnest.iter() {
            assert!(contains_unnest(sql));
        }
        assert!(!contains_unnest("SELECT * FROM t WHERE x = 1"));
    }

    /// `contains_lateral` distinguishes a LATERAL join from a plain CROSS JOIN.
    #[test]
    fn detects_lateral() {
        let lateral_found = contains_lateral("SELECT * FROM t, LATERAL UNNEST(t.ids) AS id(v)");
        assert!(lateral_found);

        let cross_join_flagged = contains_lateral("SELECT * FROM t CROSS JOIN UNNEST(t.ids)");
        assert!(!cross_join_flagged);
    }

    /// The comma form loses its LATERAL keyword and gains a CROSS JOIN.
    #[test]
    fn rewrites_lateral_unnest_with_comma() {
        let upper =
            rewrite_lateral_unnest("SELECT * FROM t, LATERAL UNNEST(t.tags) AS tag(value)")
                .to_ascii_uppercase();
        assert!(!upper.contains(" LATERAL "));
        assert!(upper.contains("CROSS JOIN UNNEST"));
    }

    /// The table alias after the UNNEST call survives the rewrite.
    #[test]
    fn rewrites_lateral_unnest_preserves_alias() {
        let rewritten = rewrite_lateral_unnest(
            "SELECT t.id, tag.value FROM t, LATERAL UNNEST(t.tags) AS tag(value)",
        );
        let alias_kept = rewritten.contains("tag(value)");
        assert!(alias_kept, "alias preserved: {rewritten}");
    }

    /// A select-list UNNEST without LATERAL passes through untouched.
    #[test]
    fn passthrough_plain_unnest() {
        const SQL: &str = "SELECT id, UNNEST(tags) AS tag FROM t";
        let rewritten = rewrite_lateral_unnest(SQL);
        assert_eq!(rewritten, SQL, "unchanged: {rewritten}");
    }

    /// An explicit CROSS JOIN UNNEST passes through untouched.
    #[test]
    fn passthrough_non_lateral_unnest_unchanged() {
        const SQL: &str = "SELECT * FROM t CROSS JOIN UNNEST(t.ids) AS elem(v)";
        assert_eq!(rewrite_lateral_unnest(SQL), SQL);
    }

    /// All three arguments land in the matching fields of `NodeOp::Unnest`.
    #[test]
    fn build_unnest_op_returns_correct_variant() {
        if let NodeOp::Unnest {
            array_column,
            output_column,
            with_ordinality,
        } = build_unnest_op("tags", "tag", false)
        {
            assert_eq!(array_column, "tags");
            assert_eq!(output_column, "tag");
            assert!(!with_ordinality);
        } else {
            panic!("expected Unnest variant");
        }
    }

    /// The ordinality flag is carried through when set.
    #[test]
    fn build_unnest_op_with_ordinality() {
        if let NodeOp::Unnest {
            with_ordinality, ..
        } = build_unnest_op("items", "item", true)
        {
            assert!(with_ordinality);
        } else {
            panic!("expected Unnest");
        }
    }
}