bestool_psql/
column_extractor.rs

1use std::collections::HashSet;
2
3use miette::Result;
4use pg_query::{NodeEnum, parse};
5use tracing::debug;
6
7use crate::schema_cache::SchemaCache;
8
/// A column reference resolved to its `schema`, `table` and `column` names.
///
/// `Hash`/`Eq` are derived so extracted references can be deduplicated; the
/// serde derives allow the result to be serialized.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ColumnRef {
	/// Schema name (e.g. `public`).
	pub schema: String,
	/// Underlying table name — the real relation name, not the alias used in the query.
	pub table: String,
	/// Column name.
	pub column: String,
}
16
17/// Extract column references from a SQL query
18pub fn extract_column_refs(
19	sql: &str,
20	schema_cache: Option<&SchemaCache>,
21) -> Result<Vec<ColumnRef>> {
22	// Parse the SQL using pg_query
23	let parse_result = match parse(sql) {
24		Ok(result) => result,
25		Err(e) => {
26			debug!("Failed to parse SQL for column extraction: {}", e);
27			return Ok(Vec::new());
28		}
29	};
30
31	let mut column_refs = Vec::new();
32	let mut context = ExtractionContext {
33		schema_cache,
34		column_refs: &mut column_refs,
35		table_aliases: Default::default(),
36		in_select_list: false,
37	};
38
39	// Process each statement in the parse tree
40	for statement in parse_result.protobuf.stmts {
41		if let Some(stmt) = statement.stmt {
42			process_node(&stmt.node, &mut context);
43		}
44	}
45
46	// Deduplicate while preserving order
47	let mut seen = HashSet::new();
48	column_refs.retain(|col_ref| seen.insert(col_ref.clone()));
49
50	Ok(column_refs)
51}
52
53struct ExtractionContext<'a> {
54	schema_cache: Option<&'a SchemaCache>,
55	column_refs: &'a mut Vec<ColumnRef>,
56	table_aliases: std::collections::HashMap<String, (String, String)>,
57	in_select_list: bool,
58}
59
60fn process_node(node: &Option<NodeEnum>, ctx: &mut ExtractionContext<'_>) {
61	let Some(node) = node else { return };
62
63	match node {
64		NodeEnum::SelectStmt(select) => {
65			// First, process FROM clause to build table aliases
66			for from_item in &select.from_clause {
67				if let Some(NodeEnum::RangeVar(range)) = &from_item.node {
68					let table_name = range.relname.clone();
69					let schema_name = if range.schemaname.is_empty() {
70						if let Some(cache) = ctx.schema_cache {
71							find_schema_for_table(cache, &table_name)
72								.unwrap_or_else(|| "public".to_string())
73						} else {
74							"public".to_string()
75						}
76					} else {
77						range.schemaname.clone()
78					};
79
80					let alias = if let Some(a) = &range.alias {
81						a.aliasname.clone()
82					} else {
83						table_name.clone()
84					};
85
86					ctx.table_aliases.insert(alias, (schema_name, table_name));
87				}
88				// Process other types of FROM items (subqueries, joins, etc)
89				process_from_item(&from_item.node, ctx);
90			}
91
92			// Process target list (SELECT items)
93			let old_in_select_list = ctx.in_select_list;
94			ctx.in_select_list = true;
95
96			for target in &select.target_list {
97				if let Some(NodeEnum::ResTarget(res)) = &target.node
98					&& let Some(val) = &res.val
99				{
100					// Check if this is a simple ColumnRef (not a computed expression)
101					if let Some(NodeEnum::ColumnRef(_)) = &val.node {
102						process_node(&val.node, ctx);
103					} else if let Some(NodeEnum::AStar(_)) = &val.node {
104						// SELECT * - expand to all columns
105						expand_wildcard(None, ctx);
106					}
107					// For other expressions (computed columns), we don't extract
108				}
109			}
110
111			ctx.in_select_list = old_in_select_list;
112
113			// Process WHERE clause
114			if let Some(where_clause) = &select.where_clause {
115				process_node(&where_clause.node, ctx);
116			}
117
118			// Process GROUP BY
119			for group in &select.group_clause {
120				process_node(&group.node, ctx);
121			}
122
123			// Process HAVING
124			if let Some(having) = &select.having_clause {
125				process_node(&having.node, ctx);
126			}
127		}
128		NodeEnum::ColumnRef(col_ref) => {
129			// Extract column reference
130			process_column_ref(col_ref, ctx);
131		}
132		NodeEnum::AStar(_) => {
133			// SELECT * or table.*
134			expand_wildcard(None, ctx);
135		}
136		NodeEnum::RangeVar(_) => {
137			// Already handled in FROM processing
138		}
139		NodeEnum::JoinExpr(join) => {
140			// Process both sides of the join
141			if let Some(larg) = &join.larg {
142				process_node(&larg.node, ctx);
143			}
144			if let Some(rarg) = &join.rarg {
145				process_node(&rarg.node, ctx);
146			}
147			// Process join condition
148			if let Some(quals) = &join.quals {
149				process_node(&quals.node, ctx);
150			}
151		}
152		NodeEnum::AExpr(expr) => {
153			// Binary/unary expressions - process operands but don't mark as direct refs
154			let old_in_select_list = ctx.in_select_list;
155			ctx.in_select_list = false;
156
157			if let Some(lexpr) = &expr.lexpr {
158				process_node(&lexpr.node, ctx);
159			}
160			if let Some(rexpr) = &expr.rexpr {
161				process_node(&rexpr.node, ctx);
162			}
163
164			ctx.in_select_list = old_in_select_list;
165		}
166		NodeEnum::BoolExpr(expr) => {
167			// Boolean expressions (AND, OR, NOT)
168			for arg in &expr.args {
169				process_node(&arg.node, ctx);
170			}
171		}
172		NodeEnum::FuncCall(_) => {
173			// Function calls are computed expressions, don't extract columns from them
174			// even though they might reference columns
175		}
176		NodeEnum::SubLink(sublink) => {
177			// Subquery - process it
178			if let Some(subselect) = &sublink.subselect {
179				process_node(&subselect.node, ctx);
180			}
181		}
182		NodeEnum::RangeSubselect(range_sub) => {
183			// Subquery in FROM clause
184			if let Some(subquery) = &range_sub.subquery {
185				process_node(&subquery.node, ctx);
186			}
187		}
188		_ => {
189			// For other node types, we don't need to extract columns
190		}
191	}
192}
193
194fn process_from_item(node: &Option<NodeEnum>, ctx: &mut ExtractionContext<'_>) {
195	let Some(node) = node else { return };
196
197	match node {
198		NodeEnum::RangeVar(range) => {
199			let table_name = range.relname.clone();
200			let schema_name = if range.schemaname.is_empty() {
201				if let Some(cache) = ctx.schema_cache {
202					find_schema_for_table(cache, &table_name)
203						.unwrap_or_else(|| "public".to_string())
204				} else {
205					"public".to_string()
206				}
207			} else {
208				range.schemaname.clone()
209			};
210
211			let alias = if let Some(a) = &range.alias {
212				a.aliasname.clone()
213			} else {
214				table_name.clone()
215			};
216
217			ctx.table_aliases.insert(alias, (schema_name, table_name));
218		}
219		NodeEnum::JoinExpr(join) => {
220			if let Some(larg) = &join.larg {
221				process_from_item(&larg.node, ctx);
222			}
223			if let Some(rarg) = &join.rarg {
224				process_from_item(&rarg.node, ctx);
225			}
226		}
227		NodeEnum::RangeSubselect(_) => {
228			// Subquery - we could track this but for now skip
229		}
230		_ => {}
231	}
232}
233
234fn process_column_ref(col_ref: &pg_query::protobuf::ColumnRef, ctx: &mut ExtractionContext<'_>) {
235	let fields: Vec<String> = col_ref
236		.fields
237		.iter()
238		.filter_map(|field| {
239			if let Some(NodeEnum::String(s)) = &field.node {
240				Some(s.sval.clone())
241			} else if let Some(NodeEnum::AStar(_)) = &field.node {
242				None // Handle * separately
243			} else {
244				None
245			}
246		})
247		.collect();
248
249	// Check if this is a wildcard (table.* or just *)
250	let has_star = col_ref
251		.fields
252		.iter()
253		.any(|field| matches!(&field.node, Some(NodeEnum::AStar(_))));
254
255	if has_star {
256		if !fields.is_empty() {
257			// table.* case
258			let table_name = &fields[0];
259			expand_wildcard(Some(table_name), ctx);
260		} else {
261			// SELECT * case (unqualified wildcard)
262			expand_wildcard(None, ctx);
263		}
264		return;
265	}
266
267	match fields.len() {
268		1 => {
269			// Simple column reference (no table qualifier)
270			let column_name = &fields[0];
271
272			// If there's only one table, use it
273			if ctx.table_aliases.len() == 1
274				&& let Some((schema, table)) = ctx.table_aliases.values().next()
275			{
276				ctx.column_refs.push(ColumnRef {
277					schema: schema.clone(),
278					table: table.clone(),
279					column: column_name.clone(),
280				});
281			}
282			// Otherwise, we can't determine which table without more analysis
283		}
284		2 => {
285			// table.column
286			let table_or_alias = &fields[0];
287			let column_name = &fields[1];
288
289			if let Some((schema, table)) = ctx.table_aliases.get(table_or_alias) {
290				ctx.column_refs.push(ColumnRef {
291					schema: schema.clone(),
292					table: table.clone(),
293					column: column_name.clone(),
294				});
295			}
296		}
297		3 => {
298			// schema.table.column
299			let schema = &fields[0];
300			let table = &fields[1];
301			let column = &fields[2];
302			ctx.column_refs.push(ColumnRef {
303				schema: schema.clone(),
304				table: table.clone(),
305				column: column.clone(),
306			});
307		}
308		_ => {}
309	}
310}
311
312fn expand_wildcard(table_qualifier: Option<&str>, ctx: &mut ExtractionContext<'_>) {
313	let Some(cache) = ctx.schema_cache else {
314		return;
315	};
316
317	if let Some(table_name) = table_qualifier {
318		// Expand table.*
319		if let Some((schema, table)) = ctx.table_aliases.get(table_name)
320			&& let Some(columns) = cache.columns_for_table(table)
321		{
322			for column in columns {
323				ctx.column_refs.push(ColumnRef {
324					schema: schema.clone(),
325					table: table.clone(),
326					column: column.clone(),
327				});
328			}
329		}
330	} else {
331		// Expand * - all columns from all tables
332		for (schema, table) in ctx.table_aliases.values() {
333			if let Some(columns) = cache.columns_for_table(table) {
334				for column in columns {
335					ctx.column_refs.push(ColumnRef {
336						schema: schema.clone(),
337						table: table.clone(),
338						column: column.clone(),
339					});
340				}
341			}
342		}
343	}
344}
345
346fn find_schema_for_table(cache: &SchemaCache, table: &str) -> Option<String> {
347	// First check if it exists in public schema
348	if cache
349		.tables
350		.get("public")
351		.is_some_and(|tables| tables.contains(&table.to_string()))
352	{
353		return Some("public".to_string());
354	}
355
356	// Check other schemas
357	for (schema_name, tables) in &cache.tables {
358		if tables.contains(&table.to_string()) {
359			return Some(schema_name.clone());
360		}
361	}
362
363	None
364}
365
#[cfg(test)]
mod tests {
	use super::*;

	/// Cache with a single `public.patient` table whose columns are `foo`, `bar`, `baz`.
	fn create_test_cache() -> SchemaCache {
		let mut cache = SchemaCache::new();
		cache
			.tables
			.insert("public".to_string(), vec!["patient".to_string()]);
		cache.columns.insert(
			"public.patient".to_string(),
			vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
		);
		cache.columns.insert(
			"patient".to_string(),
			vec!["foo".to_string(), "bar".to_string(), "baz".to_string()],
		);
		cache
	}

	#[test]
	fn test_parse_structure() {
		// Pins the pg_query shape the extractor relies on: `SELECT *` is a single
		// ResTarget whose value is a ColumnRef ending in an AStar field.
		let sql = "SELECT * FROM patient";
		let result = pg_query::parse(sql).unwrap();

		assert_eq!(result.protobuf.stmts.len(), 1);
		let stmt_node = result.protobuf.stmts[0]
			.stmt
			.as_ref()
			.and_then(|s| s.node.as_ref());
		let Some(pg_query::NodeEnum::SelectStmt(select)) = stmt_node else {
			panic!("expected a SELECT statement, got {stmt_node:?}");
		};

		assert_eq!(select.target_list.len(), 1);
		let Some(pg_query::NodeEnum::ResTarget(res)) = &select.target_list[0].node else {
			panic!("expected a ResTarget in the target list");
		};
		let val_node = res.val.as_ref().and_then(|v| v.node.as_ref());
		let Some(pg_query::NodeEnum::ColumnRef(col)) = val_node else {
			panic!("expected a ColumnRef, got {val_node:?}");
		};
		assert!(matches!(
			col.fields.last().and_then(|f| f.node.as_ref()),
			Some(pg_query::NodeEnum::AStar(_))
		));
	}

	#[test]
	fn test_simple_select() {
		let cache = create_test_cache();
		let sql = "SELECT foo, bar FROM patient WHERE bar = 123";
		let refs = extract_column_refs(sql, Some(&cache)).unwrap();

		assert_eq!(refs.len(), 2);
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "foo".into()
		}));
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "bar".into()
		}));
	}

	#[test]
	fn test_select_with_expression() {
		let cache = create_test_cache();
		let sql = "SELECT bar, foo + 2 FROM patient";
		let refs = extract_column_refs(sql, Some(&cache)).unwrap();

		// Should only return 'bar', not 'foo' because it's part of an expression
		assert_eq!(refs.len(), 1);
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "bar".into()
		}));
	}

	#[test]
	fn test_select_star() {
		let cache = create_test_cache();
		let sql = "SELECT * FROM patient";
		let refs = extract_column_refs(sql, Some(&cache)).unwrap();

		assert_eq!(refs.len(), 3);
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "foo".into()
		}));
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "bar".into()
		}));
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "baz".into()
		}));
	}

	#[test]
	fn test_select_qualified_columns() {
		let cache = create_test_cache();
		let sql = "SELECT patient.foo, patient.bar FROM patient";
		let refs = extract_column_refs(sql, Some(&cache)).unwrap();

		assert_eq!(refs.len(), 2);
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "foo".into()
		}));
		assert!(refs.contains(&ColumnRef {
			schema: "public".into(),
			table: "patient".into(),
			column: "bar".into()
		}));
	}

	#[test]
	fn test_select_aliased_table() {
		let cache = create_test_cache();
		let sql = "SELECT p.foo FROM patient p";
		let refs = extract_column_refs(sql, Some(&cache)).unwrap();

		assert_eq!(
			refs,
			vec![ColumnRef {
				schema: "public".into(),
				table: "patient".into(),
				column: "foo".into()
			}]
		);
	}

	#[test]
	fn test_select_without_cache_defaults_to_public() {
		let refs = extract_column_refs("SELECT foo FROM patient", None).unwrap();

		assert_eq!(
			refs,
			vec![ColumnRef {
				schema: "public".into(),
				table: "patient".into(),
				column: "foo".into()
			}]
		);
	}

	#[test]
	fn test_ambiguous_unqualified_column_is_skipped() {
		// With two tables in scope an unqualified column cannot be attributed.
		let cache = create_test_cache();
		let refs = extract_column_refs("SELECT foo FROM patient, visit", Some(&cache)).unwrap();

		assert!(refs.is_empty());
	}

	#[test]
	fn test_unparseable_sql_returns_empty() {
		let refs = extract_column_refs("SELEC foo FROM", None).unwrap();

		assert!(refs.is_empty());
	}
}
480}