1use std::collections::HashSet;
2
3use miette::Result;
4use pg_query::{NodeEnum, parse};
5use tracing::debug;
6
7use crate::schema_cache::SchemaCache;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
11pub struct ColumnRef {
12 pub schema: String,
13 pub table: String,
14 pub column: String,
15}
16
17pub fn extract_column_refs(
19 sql: &str,
20 schema_cache: Option<&SchemaCache>,
21) -> Result<Vec<ColumnRef>> {
22 let parse_result = match parse(sql) {
24 Ok(result) => result,
25 Err(e) => {
26 debug!("Failed to parse SQL for column extraction: {}", e);
27 return Ok(Vec::new());
28 }
29 };
30
31 let mut column_refs = Vec::new();
32 let mut context = ExtractionContext {
33 schema_cache,
34 column_refs: &mut column_refs,
35 table_aliases: Default::default(),
36 in_select_list: false,
37 };
38
39 for statement in parse_result.protobuf.stmts {
41 if let Some(stmt) = statement.stmt {
42 process_node(&stmt.node, &mut context);
43 }
44 }
45
46 let mut seen = HashSet::new();
48 column_refs.retain(|col_ref| seen.insert(col_ref.clone()));
49
50 Ok(column_refs)
51}
52
53struct ExtractionContext<'a> {
54 schema_cache: Option<&'a SchemaCache>,
55 column_refs: &'a mut Vec<ColumnRef>,
56 table_aliases: std::collections::HashMap<String, (String, String)>,
57 in_select_list: bool,
58}
59
60fn process_node(node: &Option<NodeEnum>, ctx: &mut ExtractionContext<'_>) {
61 let Some(node) = node else { return };
62
63 match node {
64 NodeEnum::SelectStmt(select) => {
65 for from_item in &select.from_clause {
67 if let Some(NodeEnum::RangeVar(range)) = &from_item.node {
68 let table_name = range.relname.clone();
69 let schema_name = if range.schemaname.is_empty() {
70 if let Some(cache) = ctx.schema_cache {
71 find_schema_for_table(cache, &table_name)
72 .unwrap_or_else(|| "public".to_string())
73 } else {
74 "public".to_string()
75 }
76 } else {
77 range.schemaname.clone()
78 };
79
80 let alias = if let Some(a) = &range.alias {
81 a.aliasname.clone()
82 } else {
83 table_name.clone()
84 };
85
86 ctx.table_aliases.insert(alias, (schema_name, table_name));
87 }
88 process_from_item(&from_item.node, ctx);
90 }
91
92 let old_in_select_list = ctx.in_select_list;
94 ctx.in_select_list = true;
95
96 for target in &select.target_list {
97 if let Some(NodeEnum::ResTarget(res)) = &target.node
98 && let Some(val) = &res.val
99 {
100 if let Some(NodeEnum::ColumnRef(_)) = &val.node {
102 process_node(&val.node, ctx);
103 } else if let Some(NodeEnum::AStar(_)) = &val.node {
104 expand_wildcard(None, ctx);
106 }
107 }
109 }
110
111 ctx.in_select_list = old_in_select_list;
112
113 if let Some(where_clause) = &select.where_clause {
115 process_node(&where_clause.node, ctx);
116 }
117
118 for group in &select.group_clause {
120 process_node(&group.node, ctx);
121 }
122
123 if let Some(having) = &select.having_clause {
125 process_node(&having.node, ctx);
126 }
127 }
128 NodeEnum::ColumnRef(col_ref) => {
129 process_column_ref(col_ref, ctx);
131 }
132 NodeEnum::AStar(_) => {
133 expand_wildcard(None, ctx);
135 }
136 NodeEnum::RangeVar(_) => {
137 }
139 NodeEnum::JoinExpr(join) => {
140 if let Some(larg) = &join.larg {
142 process_node(&larg.node, ctx);
143 }
144 if let Some(rarg) = &join.rarg {
145 process_node(&rarg.node, ctx);
146 }
147 if let Some(quals) = &join.quals {
149 process_node(&quals.node, ctx);
150 }
151 }
152 NodeEnum::AExpr(expr) => {
153 let old_in_select_list = ctx.in_select_list;
155 ctx.in_select_list = false;
156
157 if let Some(lexpr) = &expr.lexpr {
158 process_node(&lexpr.node, ctx);
159 }
160 if let Some(rexpr) = &expr.rexpr {
161 process_node(&rexpr.node, ctx);
162 }
163
164 ctx.in_select_list = old_in_select_list;
165 }
166 NodeEnum::BoolExpr(expr) => {
167 for arg in &expr.args {
169 process_node(&arg.node, ctx);
170 }
171 }
172 NodeEnum::FuncCall(_) => {
173 }
176 NodeEnum::SubLink(sublink) => {
177 if let Some(subselect) = &sublink.subselect {
179 process_node(&subselect.node, ctx);
180 }
181 }
182 NodeEnum::RangeSubselect(range_sub) => {
183 if let Some(subquery) = &range_sub.subquery {
185 process_node(&subquery.node, ctx);
186 }
187 }
188 _ => {
189 }
191 }
192}
193
194fn process_from_item(node: &Option<NodeEnum>, ctx: &mut ExtractionContext<'_>) {
195 let Some(node) = node else { return };
196
197 match node {
198 NodeEnum::RangeVar(range) => {
199 let table_name = range.relname.clone();
200 let schema_name = if range.schemaname.is_empty() {
201 if let Some(cache) = ctx.schema_cache {
202 find_schema_for_table(cache, &table_name)
203 .unwrap_or_else(|| "public".to_string())
204 } else {
205 "public".to_string()
206 }
207 } else {
208 range.schemaname.clone()
209 };
210
211 let alias = if let Some(a) = &range.alias {
212 a.aliasname.clone()
213 } else {
214 table_name.clone()
215 };
216
217 ctx.table_aliases.insert(alias, (schema_name, table_name));
218 }
219 NodeEnum::JoinExpr(join) => {
220 if let Some(larg) = &join.larg {
221 process_from_item(&larg.node, ctx);
222 }
223 if let Some(rarg) = &join.rarg {
224 process_from_item(&rarg.node, ctx);
225 }
226 }
227 NodeEnum::RangeSubselect(_) => {
228 }
230 _ => {}
231 }
232}
233
234fn process_column_ref(col_ref: &pg_query::protobuf::ColumnRef, ctx: &mut ExtractionContext<'_>) {
235 let fields: Vec<String> = col_ref
236 .fields
237 .iter()
238 .filter_map(|field| {
239 if let Some(NodeEnum::String(s)) = &field.node {
240 Some(s.sval.clone())
241 } else if let Some(NodeEnum::AStar(_)) = &field.node {
242 None } else {
244 None
245 }
246 })
247 .collect();
248
249 let has_star = col_ref
251 .fields
252 .iter()
253 .any(|field| matches!(&field.node, Some(NodeEnum::AStar(_))));
254
255 if has_star {
256 if !fields.is_empty() {
257 let table_name = &fields[0];
259 expand_wildcard(Some(table_name), ctx);
260 } else {
261 expand_wildcard(None, ctx);
263 }
264 return;
265 }
266
267 match fields.len() {
268 1 => {
269 let column_name = &fields[0];
271
272 if ctx.table_aliases.len() == 1
274 && let Some((schema, table)) = ctx.table_aliases.values().next()
275 {
276 ctx.column_refs.push(ColumnRef {
277 schema: schema.clone(),
278 table: table.clone(),
279 column: column_name.clone(),
280 });
281 }
282 }
284 2 => {
285 let table_or_alias = &fields[0];
287 let column_name = &fields[1];
288
289 if let Some((schema, table)) = ctx.table_aliases.get(table_or_alias) {
290 ctx.column_refs.push(ColumnRef {
291 schema: schema.clone(),
292 table: table.clone(),
293 column: column_name.clone(),
294 });
295 }
296 }
297 3 => {
298 let schema = &fields[0];
300 let table = &fields[1];
301 let column = &fields[2];
302 ctx.column_refs.push(ColumnRef {
303 schema: schema.clone(),
304 table: table.clone(),
305 column: column.clone(),
306 });
307 }
308 _ => {}
309 }
310}
311
312fn expand_wildcard(table_qualifier: Option<&str>, ctx: &mut ExtractionContext<'_>) {
313 let Some(cache) = ctx.schema_cache else {
314 return;
315 };
316
317 if let Some(table_name) = table_qualifier {
318 if let Some((schema, table)) = ctx.table_aliases.get(table_name)
320 && let Some(columns) = cache.columns_for_table(table)
321 {
322 for column in columns {
323 ctx.column_refs.push(ColumnRef {
324 schema: schema.clone(),
325 table: table.clone(),
326 column: column.clone(),
327 });
328 }
329 }
330 } else {
331 for (schema, table) in ctx.table_aliases.values() {
333 if let Some(columns) = cache.columns_for_table(table) {
334 for column in columns {
335 ctx.column_refs.push(ColumnRef {
336 schema: schema.clone(),
337 table: table.clone(),
338 column: column.clone(),
339 });
340 }
341 }
342 }
343 }
344}
345
346fn find_schema_for_table(cache: &SchemaCache, table: &str) -> Option<String> {
347 if cache
349 .tables
350 .get("public")
351 .is_some_and(|tables| tables.contains(&table.to_string()))
352 {
353 return Some("public".to_string());
354 }
355
356 for (schema_name, tables) in &cache.tables {
358 if tables.contains(&table.to_string()) {
359 return Some(schema_name.clone());
360 }
361 }
362
363 None
364}
365
366#[cfg(test)]
367mod tests {
368 use super::*;
369
370 fn create_test_cache() -> SchemaCache {
371 let mut cache = SchemaCache::new();
372 cache
373 .tables
374 .insert("public".to_string(), vec!["patient".to_string()]);
375 cache.columns.insert(
376 "public.patient".to_string(),
377 vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
378 );
379 cache.columns.insert(
380 "patient".to_string(),
381 vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
382 );
383 cache
384 }
385
386 #[test]
387 fn test_parse_structure() {
388 let sql = "SELECT * FROM patient";
389 let result = pg_query::parse(sql).unwrap();
390
391 for stmt in &result.protobuf.stmts {
393 if let Some(s) = &stmt.stmt {
394 if let Some(pg_query::NodeEnum::SelectStmt(select)) = &s.node {
395 eprintln!("Target list length: {}", select.target_list.len());
396 for (i, target) in select.target_list.iter().enumerate() {
397 eprintln!("Target {}: {:?}", i, target.node);
398 }
399 }
400 }
401 }
402 }
403
404 #[test]
405 fn test_simple_select() {
406 let cache = create_test_cache();
407 let sql = "SELECT foo, bar FROM patient WHERE bar = 123";
408 let refs = extract_column_refs(sql, Some(&cache)).unwrap();
409
410 assert_eq!(refs.len(), 2);
411 assert!(refs.contains(&ColumnRef {
412 schema: "public".into(),
413 table: "patient".into(),
414 column: "foo".into()
415 }));
416 assert!(refs.contains(&ColumnRef {
417 schema: "public".into(),
418 table: "patient".into(),
419 column: "bar".into()
420 }));
421 }
422
423 #[test]
424 fn test_select_with_expression() {
425 let cache = create_test_cache();
426 let sql = "SELECT bar, foo + 2 FROM patient";
427 let refs = extract_column_refs(sql, Some(&cache)).unwrap();
428
429 assert_eq!(refs.len(), 1);
431 assert!(refs.contains(&ColumnRef {
432 schema: "public".into(),
433 table: "patient".into(),
434 column: "bar".into()
435 }));
436 }
437
438 #[test]
439 fn test_select_star() {
440 let cache = create_test_cache();
441 let sql = "SELECT * FROM patient";
442 let refs = extract_column_refs(sql, Some(&cache)).unwrap();
443
444 assert_eq!(refs.len(), 3);
445 assert!(refs.contains(&ColumnRef {
446 schema: "public".into(),
447 table: "patient".into(),
448 column: "foo".into()
449 }));
450 assert!(refs.contains(&ColumnRef {
451 schema: "public".into(),
452 table: "patient".into(),
453 column: "bar".into()
454 }));
455 assert!(refs.contains(&ColumnRef {
456 schema: "public".into(),
457 table: "patient".into(),
458 column: "baz".into()
459 }));
460 }
461
462 #[test]
463 fn test_select_qualified_columns() {
464 let cache = create_test_cache();
465 let sql = "SELECT patient.foo, patient.bar FROM patient";
466 let refs = extract_column_refs(sql, Some(&cache)).unwrap();
467
468 assert_eq!(refs.len(), 2);
469 assert!(refs.contains(&ColumnRef {
470 schema: "public".into(),
471 table: "patient".into(),
472 column: "foo".into()
473 }));
474 assert!(refs.contains(&ColumnRef {
475 schema: "public".into(),
476 table: "patient".into(),
477 column: "bar".into()
478 }));
479 }
480}