use std::cmp::Ordering;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::{Arc, Mutex, OnceLock};

use arrow::compute::kernels::cast_utils::IntervalUnit;
use arrow_schema::{DataType, Schema, SchemaRef};
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::common::{
    Column, Constraints, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, TableReference,
};
use datafusion::datasource::listing::{ListingTable, ListingTableConfig};
use datafusion::datasource::{TableProvider, provider_as_source, source_as_provider};
use datafusion::logical_expr::logical_plan::{
    Aggregate, Distinct, DistinctOn, Filter, Join, Limit, LogicalPlan, Partitioning, Projection,
    Repartition, Sort, SubqueryAlias, TableScan, Union, Window,
};
use datafusion::logical_expr::{Expr, TableSource};
use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
use datafusion::physical_expr_adapter::{
    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
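/// Date/time components that the lineage analysis recognizes in `date_part`/`EXTRACT` calls.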
pub(crate) enum SupportedIntervalUnit {
    Year,
    Month,
    Day,
    DayOfWeek,
}

impl SupportedIntervalUnit {
    pub(crate) fn metadata_value(self) -> &'static str {
        match self {
            SupportedIntervalUnit::Year => "YEAR",
            SupportedIntervalUnit::Month => "MONTH",
            SupportedIntervalUnit::Day => "DAY",
            SupportedIntervalUnit::DayOfWeek => "DOW",
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
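/// A date or timestamp column whose observed uses are all `date_part` extractions, along with
/// the set of components that were extracted.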
pub(crate) struct DateExtraction {
    pub(crate) column: Column,
    pub(crate) components: HashSet<SupportedIntervalUnit>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
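/// A variant column accessed exclusively through `variant_get`, along with the distinct paths
/// (and optional type hints) that were requested.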
pub(crate) struct VariantExtraction {
    pub(crate) column: Column,
    pub(crate) fields: Vec<VariantField>,
}

impl PartialOrd for VariantExtraction {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for VariantExtraction {
    fn cmp(&self, other: &Self) -> Ordering {
        self.column
            .flat_name()
            .cmp(&other.column.flat_name())
            .then_with(|| self.fields.cmp(&other.fields))
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VariantField {
    pub(crate) path: String,
    pub(crate) data_type: Option<DataType>,
}

impl PartialOrd for VariantField {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for VariantField {
    fn cmp(&self, other: &Self) -> Ordering {
        self.path.cmp(&other.path).then_with(|| {
            let self_ty = self
                .data_type
                .as_ref()
                .map(|dt| dt.to_string())
                .unwrap_or_default();
            let other_ty = other
                .data_type
                .as_ref()
                .map(|dt| dt.to_string())
                .unwrap_or_default();
            self_ty.cmp(&other_ty)
        })
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
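/// Per-column usage hint produced by the lineage analysis; later attached to the scan's
/// expression adapter factory so downstream planning can recover it.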
pub(crate) enum ColumnAnnotation {
    DatePart(HashSet<SupportedIntervalUnit>),
    VariantPaths(Vec<VariantField>),
    SubstringSearch,
}

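/// Serializes a set of date-part units into a stable, comma-separated metadata value,
/// e.g. `{Day, Year}` becomes `"YEAR,DAY"`.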
pub(crate) fn serialize_date_part(units: &HashSet<SupportedIntervalUnit>) -> String {
    let mut sorted_units: Vec<&SupportedIntervalUnit> = units.iter().collect();
    sorted_units.sort_by_key(|unit| match unit {
        SupportedIntervalUnit::Year => 0,
        SupportedIntervalUnit::Month => 1,
        SupportedIntervalUnit::Day => 2,
        SupportedIntervalUnit::DayOfWeek => 3,
    });
    sorted_units
        .iter()
        .map(|unit| unit.metadata_value())
        .collect::<Vec<_>>()
        .join(",")
}

#[derive(Debug, Default)]
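/// Logical optimizer rule that analyzes how columns flow through the plan and, when a column
/// is only ever used for date-part extraction, `variant_get`, or substring search, annotates
/// the underlying `ListingTable` scan so the cache layer can exploit it.
///
/// A minimal registration sketch (this mirrors the test helper `create_session_context`
/// below; adapt it to your own session configuration):
///
/// ```ignore
/// let state = SessionStateBuilder::new()
///     .with_default_features()
///     .with_optimizer_rule(Arc::new(LineageOptimizer::new()))
///     .build();
/// let ctx = SessionContext::new_with_state(state);
/// ```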
pub struct LineageOptimizer;

impl LineageOptimizer {
    pub fn new() -> Self {
        Self
    }
}

impl OptimizerRule for LineageOptimizer {
    fn name(&self) -> &str {
        "LineageOptimizer"
    }

    fn apply_order(&self) -> Option<ApplyOrder> {
        None
    }

    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
        let mut analyzer = LineageAnalyzer::default();
        let _ = analyzer.analyze_plan(&plan)?;
        let table_usage = analyzer.finish();
        let mut date_findings = table_usage.find_date32_extractions();
        date_findings.sort_by(|a, b| a.column.flat_name().cmp(&b.column.flat_name()));

        let mut variant_findings = table_usage.find_variant_gets();
        variant_findings.sort();

        let mut substring_findings = table_usage.find_substring_searches();
        substring_findings.sort_by_key(|a| a.flat_name());

        let annotations =
            build_annotation_map(&date_findings, &variant_findings, &substring_findings);
        annotate_plan_with_extractions(plan, &annotations)
    }
}

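/// Maps each output column of a plan node to the base-column usages that flow into it.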
type LineageMap = HashMap<ColumnKey, Vec<ColumnUsage>>;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ColumnKey {
    relation: Option<TableReference>,
    name: String,
}

impl ColumnKey {
    fn new(relation: Option<TableReference>, name: impl Into<String>) -> Self {
        Self {
            relation,
            name: name.into(),
        }
    }

    fn from_column(column: &Column) -> Self {
        Self {
            relation: column.relation.clone(),
            name: column.name.clone(),
        }
    }

    fn to_column(&self) -> Column {
        Column::new(self.relation.clone(), self.name.clone())
    }
}

#[derive(Debug, Clone)]
struct ColumnUsage {
    base: ColumnKey,
    data_type: DataType,
    operations: Vec<Operation>,
}

impl ColumnUsage {
    fn new_base(column: &Column, data_type: DataType) -> Self {
        Self {
            base: ColumnKey::from_column(column),
            data_type,
            operations: Vec::new(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum Operation {
    Extract(SupportedIntervalUnit),
    VariantGet {
        path: String,
        data_type: Option<DataType>,
    },
    SubstringSearch,
    Other,
}

#[derive(Debug)]
struct UsageStats {
    data_type: DataType,
    usages: Vec<Vec<Operation>>,
}

impl UsageStats {
    fn new(data_type: DataType) -> Self {
        Self {
            data_type,
            usages: Vec::new(),
        }
    }

    fn apply(&mut self, usage: &ColumnUsage) {
        self.usages.push(usage.operations.clone());
    }
}

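/// Walks the logical plan and records, for every base table column, the chains of operations
/// applied to it.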
#[derive(Default)]
struct LineageAnalyzer {
    stats: HashMap<ColumnKey, UsageStats>,
}

impl LineageAnalyzer {
    fn analyze_plan(&mut self, plan: &LogicalPlan) -> Result<LineageMap> {
        match plan {
            LogicalPlan::TableScan(scan) => self.analyze_table_scan(scan),
            LogicalPlan::Projection(projection) => self.analyze_projection(projection),
            LogicalPlan::Filter(filter) => self.analyze_filter(filter),
            LogicalPlan::Aggregate(aggregate) => self.analyze_aggregate(aggregate),
            LogicalPlan::Sort(sort) => self.analyze_sort(sort),
            LogicalPlan::Join(join) => self.analyze_join(join),
            LogicalPlan::SubqueryAlias(alias) => self.analyze_subquery_alias(alias),
            LogicalPlan::Window(window) => self.analyze_window(window),
            LogicalPlan::Limit(limit) => self.analyze_limit(limit),
            LogicalPlan::Repartition(repartition) => self.analyze_repartition(repartition),
            LogicalPlan::Union(union) => self.analyze_union(union),
            LogicalPlan::Distinct(distinct) => self.analyze_distinct(distinct),
            other => {
                let mut merged = LineageMap::new();
                for input in other.inputs() {
                    let child = self.analyze_plan(input)?;
                    merged = merge_lineage_maps(merged, child);
                }
                Ok(merged)
            }
        }
    }

    fn analyze_table_scan(&mut self, scan: &TableScan) -> Result<LineageMap> {
        let schema = scan.projected_schema.as_ref();
        let mut map = LineageMap::new();
        for (index, column) in schema.columns().iter().enumerate() {
            let field = schema.field(index);
            let usage = ColumnUsage::new_base(column, field.data_type().clone());
            map.insert(usage.base.clone(), vec![usage]);
        }

        for filter in &scan.filters {
            let usages = lineage_for_expr(filter, &map, schema)?;
            self.record(&usages);
        }

        Ok(map)
    }

    fn analyze_projection(&mut self, projection: &Projection) -> Result<LineageMap> {
        let input_map = self.analyze_plan(projection.input.as_ref())?;
        let input_schema = projection.input.schema();
        let mut output = LineageMap::new();
        for (expr, column) in projection.expr.iter().zip(projection.schema.columns()) {
            let usages = lineage_for_expr(expr, &input_map, input_schema.as_ref())?;
            self.record(&usages);
            output.insert(ColumnKey::from_column(&column), usages);
        }
        Ok(output)
    }

    fn analyze_filter(&mut self, filter: &Filter) -> Result<LineageMap> {
        let input_map = self.analyze_plan(filter.input.as_ref())?;
        let schema = filter.input.schema();
        let usages = lineage_for_expr(&filter.predicate, &input_map, schema.as_ref())?;
        self.record(&usages);
        Ok(input_map)
    }

    fn analyze_sort(&mut self, sort: &Sort) -> Result<LineageMap> {
        let input_map = self.analyze_plan(sort.input.as_ref())?;
        let schema = sort.input.schema();
        for sort_expr in &sort.expr {
            let usages = lineage_for_expr(&sort_expr.expr, &input_map, schema.as_ref())?;
            self.record(&usages);
        }
        Ok(input_map)
    }

    fn analyze_aggregate(&mut self, aggregate: &Aggregate) -> Result<LineageMap> {
        let input_map = self.analyze_plan(aggregate.input.as_ref())?;
        let schema = aggregate.input.schema();
        let mut output = LineageMap::new();
        let mut expr_iter = aggregate
            .group_expr
            .iter()
            .chain(aggregate.aggr_expr.iter());

        for column in aggregate.schema.columns() {
            if let Some(expr) = expr_iter.next() {
                let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
                self.record(&usages);
                output.insert(ColumnKey::from_column(&column), usages);
            } else {
                output.insert(ColumnKey::from_column(&column), Vec::new());
            }
        }

        Ok(output)
    }

    fn analyze_join(&mut self, join: &Join) -> Result<LineageMap> {
        let left_map = self.analyze_plan(join.left.as_ref())?;
        let right_map = self.analyze_plan(join.right.as_ref())?;
        let left_schema = join.left.schema();
        let right_schema = join.right.schema();

        for (left_expr, right_expr) in &join.on {
            let left_usages = lineage_for_expr(left_expr, &left_map, left_schema.as_ref())?;
            self.record(&left_usages);
            let right_usages = lineage_for_expr(right_expr, &right_map, right_schema.as_ref())?;
            self.record(&right_usages);
        }

        if let Some(filter) = &join.filter {
            let mut combined = left_map.clone();
            merge_lineage_map_inplace(&mut combined, &right_map);
            let usages = lineage_for_expr(filter, &combined, join.schema.as_ref())?;
            self.record(&usages);
        }

        let left_columns = left_schema.columns();
        let right_columns = right_schema.columns();
        let mut output_columns = join.schema.columns().into_iter();
        let mut output = LineageMap::new();

        for column in left_columns {
            if let Some(output_column) = output_columns.next() {
                let key = ColumnKey::from_column(&output_column);
                let usages = left_map
                    .get(&ColumnKey::from_column(&column))
                    .cloned()
                    .unwrap_or_default();
                output.insert(key, usages);
            }
        }

        for column in right_columns {
            if let Some(output_column) = output_columns.next() {
                let key = ColumnKey::from_column(&output_column);
                let usages = right_map
                    .get(&ColumnKey::from_column(&column))
                    .cloned()
                    .unwrap_or_default();
                output.insert(key, usages);
            }
        }

        Ok(output)
    }

    fn analyze_subquery_alias(&mut self, alias: &SubqueryAlias) -> Result<LineageMap> {
        let input_map = self.analyze_plan(alias.input.as_ref())?;
        let input_columns = alias.input.schema().columns();
        let mut output = LineageMap::new();
        for (input_column, output_column) in
            input_columns.iter().zip(alias.schema.columns().into_iter())
        {
            let key = ColumnKey::from_column(&output_column);
            let usages = input_map
                .get(&ColumnKey::from_column(input_column))
                .cloned()
                .unwrap_or_default();
            output.insert(key, usages);
        }
        Ok(output)
    }

    fn analyze_window(&mut self, window: &Window) -> Result<LineageMap> {
        let input_map = self.analyze_plan(window.input.as_ref())?;
        let input_schema = window.input.schema();

        let input_cols = input_schema.columns();
        let output_cols = window.schema.columns();
        let mut output = LineageMap::new();

        for (input_column, output_column) in input_cols.iter().zip(output_cols.iter()) {
            let key = ColumnKey::from_column(output_column);
            let usages = input_map
                .get(&ColumnKey::from_column(input_column))
                .cloned()
                .unwrap_or_default();
            output.insert(key, usages);
        }

        for (expr, output_column) in window
            .window_expr
            .iter()
            .zip(output_cols.into_iter().skip(input_cols.len()))
        {
            let usages = lineage_for_expr(expr, &input_map, input_schema.as_ref())?;
            self.record(&usages);
            output.insert(ColumnKey::from_column(&output_column), usages);
        }

        Ok(output)
    }

    fn analyze_limit(&mut self, limit: &Limit) -> Result<LineageMap> {
        let map = self.analyze_plan(limit.input.as_ref())?;
        let schema = limit.input.schema();
        if let Some(skip) = &limit.skip {
            let usages = lineage_for_expr(skip, &map, schema.as_ref())?;
            self.record(&usages);
        }
        if let Some(fetch) = &limit.fetch {
            let usages = lineage_for_expr(fetch, &map, schema.as_ref())?;
            self.record(&usages);
        }
        Ok(map)
    }

    fn analyze_repartition(&mut self, repartition: &Repartition) -> Result<LineageMap> {
        let map = self.analyze_plan(repartition.input.as_ref())?;
        let schema = repartition.input.schema();
        if let Partitioning::Hash(exprs, _) | Partitioning::DistributeBy(exprs) =
            &repartition.partitioning_scheme
        {
            for expr in exprs {
                let usages = lineage_for_expr(expr, &map, schema.as_ref())?;
                self.record(&usages);
            }
        }
        Ok(map)
    }

    fn analyze_union(&mut self, union: &Union) -> Result<LineageMap> {
        let mut input_maps: Vec<LineageMap> = Vec::with_capacity(union.inputs.len());
        for input in &union.inputs {
            input_maps.push(self.analyze_plan(input.as_ref())?);
        }

        let mut output = LineageMap::new();
        for output_column in union.schema.columns() {
            let key = ColumnKey::from_column(&output_column);
            let mut combined: Vec<ColumnUsage> = Vec::new();
            for map in &input_maps {
                for (candidate_key, usages) in map {
                    if candidate_key.name == key.name {
                        combined.extend(usages.clone());
                    }
                }
            }
            output.insert(key, combined);
        }
        Ok(output)
    }

    fn analyze_distinct(&mut self, distinct: &Distinct) -> Result<LineageMap> {
        match distinct {
            Distinct::All(plan) => self.analyze_plan(plan.as_ref()),
            Distinct::On(distinct_on) => self.analyze_distinct_on(distinct_on),
        }
    }

    fn analyze_distinct_on(&mut self, distinct_on: &DistinctOn) -> Result<LineageMap> {
        let input_map = self.analyze_plan(distinct_on.input.as_ref())?;
        let schema = distinct_on.input.schema();

        for expr in &distinct_on.on_expr {
            let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
            self.record(&usages);
        }
        for expr in &distinct_on.select_expr {
            let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
            self.record(&usages);
        }
        if let Some(sort_exprs) = &distinct_on.sort_expr {
            for sort_expr in sort_exprs {
                let usages = lineage_for_expr(&sort_expr.expr, &input_map, schema.as_ref())?;
                self.record(&usages);
            }
        }

        let mut output = LineageMap::new();
        for (expr, column) in distinct_on
            .select_expr
            .iter()
            .zip(distinct_on.schema.columns().into_iter())
        {
            let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
            output.insert(ColumnKey::from_column(&column), usages);
        }
        Ok(output)
    }

    fn record(&mut self, usages: &[ColumnUsage]) {
        for usage in usages {
            let entry = self
                .stats
                .entry(usage.base.clone())
                .or_insert_with(|| UsageStats::new(usage.data_type.clone()));
            entry.apply(usage);
        }
    }

    fn finish(self) -> TableColumnUsage {
        TableColumnUsage { usage: self.stats }
    }
}

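/// Final per-column usage statistics, queried for the supported extraction patterns.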
struct TableColumnUsage {
    usage: HashMap<ColumnKey, UsageStats>,
}

impl TableColumnUsage {
    fn find_date32_extractions(&self) -> Vec<DateExtraction> {
        let mut extractions = Vec::new();
        for (key, stats) in self.usage.iter() {
            if is_date_part_type(&stats.data_type) {
                let mut all_units = HashSet::new();
                let mut all_paths_valid = true;

                for usage in &stats.usages {
                    let mut path_units = HashSet::new();
                    for op in usage {
                        match op {
                            Operation::Extract(unit) => {
                                path_units.insert(unit);
                            }
                            _ => {
                                break;
                            }
                        }
                    }

                    if path_units.is_empty() {
                        all_paths_valid = false;
                        break;
                    }

                    all_units.extend(path_units);
                }

                if all_paths_valid && !all_units.is_empty() {
                    extractions.push(DateExtraction {
                        column: key.to_column(),
                        components: all_units,
                    });
                }
            }
        }
        extractions
    }

    fn find_variant_gets(&self) -> Vec<VariantExtraction> {
        let mut gets = Vec::new();
        for (key, stats) in self.usage.iter() {
            if stats.usages.is_empty() {
                continue;
            }

            let mut field_map: HashMap<String, VariantField> = HashMap::new();
            let mut valid = true;
            let mut saw_variant_get = false;
            for usage in &stats.usages {
                match usage.first() {
                    Some(Operation::VariantGet { path, data_type }) => {
                        saw_variant_get = true;
                        match field_map.entry(path.clone()) {
                            Entry::Vacant(entry) => {
                                entry.insert(VariantField {
                                    path: path.clone(),
                                    data_type: data_type.clone(),
                                });
                            }
                            Entry::Occupied(entry) => {
                                let current = entry.into_mut();
                                let conflict = match (&current.data_type, data_type) {
                                    (Some(existing), Some(new_ty)) => existing != new_ty,
                                    (Some(_), None) | (None, Some(_)) => true,
                                    (None, None) => false,
                                };
                                if conflict {
                                    valid = false;
                                    break;
                                }
                            }
                        }
                    }
                    None => continue,
                    _ => {
                        valid = false;
                        break;
                    }
                }
            }

            if valid && saw_variant_get && !field_map.is_empty() {
                let mut fields: Vec<VariantField> = field_map.into_values().collect();
                fields.sort();
                gets.push(VariantExtraction {
                    column: key.to_column(),
                    fields,
                });
            }
        }
        gets
    }

    fn find_substring_searches(&self) -> Vec<Column> {
        let mut columns = Vec::new();
        for (key, stats) in self.usage.iter() {
            if !is_string_type(&stats.data_type) {
                continue;
            }

            if stats.usages.is_empty() {
                continue;
            }

            let mut saw_substring = false;
            let mut valid = true;
            for usage in &stats.usages {
                let has_substring = usage
                    .iter()
                    .any(|op| matches!(op, Operation::SubstringSearch));
                if has_substring {
                    saw_substring = true;
                    continue;
                }
                if !usage.is_empty() {
                    valid = false;
                    break;
                }
            }

            if valid && saw_substring {
                columns.push(key.to_column());
            }
        }
        columns
    }
}

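/// Collapses the individual findings into a single annotation map keyed by column.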
fn build_annotation_map(
    date_findings: &[DateExtraction],
    variant_findings: &[VariantExtraction],
    substring_findings: &[Column],
) -> HashMap<ColumnKey, ColumnAnnotation> {
    let mut annotations: HashMap<ColumnKey, ColumnAnnotation> = HashMap::new();
    for extraction in date_findings {
        annotations.insert(
            ColumnKey::from_column(&extraction.column),
            ColumnAnnotation::DatePart(extraction.components.clone()),
        );
    }
    for extraction in variant_findings {
        annotations.insert(
            ColumnKey::from_column(&extraction.column),
            ColumnAnnotation::VariantPaths(extraction.fields.clone()),
        );
    }
    for column in substring_findings {
        annotations.insert(
            ColumnKey::from_column(column),
            ColumnAnnotation::SubstringSearch,
        );
    }
    annotations
}

fn annotate_plan_with_extractions(
    plan: LogicalPlan,
    annotations: &HashMap<ColumnKey, ColumnAnnotation>,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
    if annotations.is_empty() {
        return Ok(Transformed::no(plan));
    }

    plan.transform_up(|logical_plan| match logical_plan {
        LogicalPlan::TableScan(mut scan) => {
            let table_annotations = annotations_for_table_scan(&scan, annotations);
            let mut changed = false;

            if let Some(source) = annotate_listing_table_source(&scan.source, &table_annotations)? {
                scan.source = source;
                changed = true;
            }

            if changed {
                Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
            } else {
                Ok(Transformed::no(LogicalPlan::TableScan(scan)))
            }
        }
        other => Ok(Transformed::no(other)),
    })
}

fn annotations_for_table_scan(
    scan: &TableScan,
    annotations: &HashMap<ColumnKey, ColumnAnnotation>,
) -> HashMap<String, ColumnAnnotation> {
    let mut table_annotations = HashMap::new();

    for (qualifier_opt, field_ref) in scan.projected_schema.iter() {
        let qualifier_owned = qualifier_opt.cloned();
        let name = field_ref.name().clone();
        if let Some(unit) = annotations
            .get(&ColumnKey::new(qualifier_owned.clone(), name.clone()))
            .cloned()
            .or_else(|| {
                annotations
                    .get(&ColumnKey::new(None, name.clone()))
                    .cloned()
            })
        {
            table_annotations.insert(name, unit);
        }
    }

    table_annotations
}

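/// Rebuilds a `ListingTable` source with a fresh expression adapter factory that carries the
/// column annotations; returns `Ok(None)` when there is nothing to annotate or the source is
/// not a `ListingTable`.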
fn annotate_listing_table_source(
    source: &Arc<dyn TableSource>,
    annotations: &HashMap<String, ColumnAnnotation>,
) -> Result<Option<Arc<dyn TableSource>>, DataFusionError> {
    if annotations.is_empty() {
        return Ok(None);
    }

    let provider = match source_as_provider(source) {
        Ok(provider) => provider,
        Err(_) => return Ok(None),
    };

    let Some(listing) = provider.as_any().downcast_ref::<ListingTable>() else {
        return Ok(None);
    };

    let metadata_copy = annotations.clone();
    let new_factory: Arc<dyn PhysicalExprAdapterFactory> = Arc::new(
        LineageExtractPhysicalExprAdapterFactory::new(annotations.clone()),
    );
    register_factory_metadata(&new_factory, metadata_copy);
    let mut new_listing = ListingTable::try_new(
        ListingTableConfig::new_with_multi_paths(listing.table_paths().clone())
            .with_listing_options(listing.options().clone())
            .with_schema(listing_file_schema(listing))
            .with_expr_adapter_factory(new_factory),
    )?;
    new_listing = new_listing.with_constraints(listing_constraints(listing));
    new_listing = new_listing.with_definition(
        listing
            .get_table_definition()
            .map(std::string::ToString::to_string),
    );

    let new_provider: Arc<dyn TableProvider> = Arc::new(new_listing);
    Ok(Some(provider_as_source(new_provider)))
}

#[derive(Debug)]
struct LineageExtractPhysicalExprAdapterFactory {
    base: Arc<dyn PhysicalExprAdapterFactory>,
    _annotations: HashMap<String, ColumnAnnotation>,
}

impl LineageExtractPhysicalExprAdapterFactory {
    fn new(annotations: HashMap<String, ColumnAnnotation>) -> Self {
        Self {
            base: Arc::new(DefaultPhysicalExprAdapterFactory),
            _annotations: annotations,
        }
    }
}

impl PhysicalExprAdapterFactory for LineageExtractPhysicalExprAdapterFactory {
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        self.base.create(logical_file_schema, physical_file_schema)
    }
}

fn listing_file_schema(listing: &ListingTable) -> SchemaRef {
    let table_schema = listing.schema();
    let file_field_count = table_schema
        .fields()
        .len()
        .saturating_sub(listing.options().table_partition_cols.len());
    let fields = table_schema
        .fields()
        .iter()
        .take(file_field_count)
        .cloned()
        .collect::<Vec<_>>();
    Arc::new(Schema::new(fields).with_metadata(table_schema.metadata().clone()))
}

fn listing_constraints(listing: &ListingTable) -> Constraints {
    listing.constraints().cloned().unwrap_or_default()
}

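/// Process-wide registry keyed by the factory's pointer identity; lets later planning stages
/// recover the annotations through `metadata_from_factory`.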
fn factory_registry() -> &'static Mutex<HashMap<usize, HashMap<String, ColumnAnnotation>>> {
    static REGISTRY: OnceLock<Mutex<HashMap<usize, HashMap<String, ColumnAnnotation>>>> =
        OnceLock::new();
    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}

fn register_factory_metadata(
    factory: &Arc<dyn PhysicalExprAdapterFactory>,
    metadata: HashMap<String, ColumnAnnotation>,
) {
    let key = Arc::as_ptr(factory) as *const () as usize;
    factory_registry().lock().unwrap().insert(key, metadata);
}

pub(crate) fn metadata_from_factory(
    factory: &Arc<dyn PhysicalExprAdapterFactory>,
    column: &str,
) -> Option<ColumnAnnotation> {
    let key = Arc::as_ptr(factory) as *const () as usize;
    factory_registry()
        .lock()
        .unwrap()
        .get(&key)
        .and_then(|map| map.get(column).cloned())
}

fn merge_lineage_maps(mut base: LineageMap, other: LineageMap) -> LineageMap {
    for (key, usages) in other {
        base.entry(key).or_default().extend(usages);
    }
    base
}

fn merge_lineage_map_inplace(base: &mut LineageMap, other: &LineageMap) {
    for (key, usages) in other {
        base.entry(key.clone()).or_default().extend(usages.clone());
    }
}

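/// Returns the column usages contributed by `expr`, appending an `Operation` for each
/// recognized pattern (`date_part`, `variant_get`, substring `LIKE`) and `Operation::Other`
/// for everything else.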
fn lineage_for_expr(
    expr: &Expr,
    input_lineage: &LineageMap,
    schema: &DFSchema,
) -> Result<Vec<ColumnUsage>> {
    match expr {
        Expr::Column(column) => {
            let key = ColumnKey::from_column(column);
            if let Some(usages) = input_lineage.get(&key) {
                Ok(usages.clone())
            } else {
                let field = schema.field_from_column(column)?;
                Ok(vec![ColumnUsage::new_base(
                    column,
                    field.data_type().clone(),
                )])
            }
        }
        Expr::Alias(alias) => lineage_for_expr(&alias.expr, input_lineage, schema),
        Expr::ScalarFunction(func) => {
            let func_name = func.func.name();
            if func_name.eq_ignore_ascii_case("date_part")
                && func.args.len() == 2
                && let Some(component) = part_to_unit(&func.args[0])
            {
                let mut usages = lineage_for_expr(&func.args[1], input_lineage, schema)?;
                for usage in &mut usages {
                    usage.operations.push(Operation::Extract(component));
                }
                return Ok(usages);
            } else if func_name.eq_ignore_ascii_case("variant_get")
                && (func.args.len() == 2 || func.args.len() == 3)
                && let Some(path) = literal_utf8(&func.args[1])
            {
                let type_hint = func.args.get(2).and_then(literal_data_type);
                let mut usages = lineage_for_expr(&func.args[0], input_lineage, schema)?;
                for usage in &mut usages {
                    usage.operations.push(Operation::VariantGet {
                        path: path.clone(),
                        data_type: type_hint.clone(),
                    });
                }
                return Ok(usages);
            }
            propagate_other(expr, input_lineage, schema)
        }
        Expr::Like(like) => {
            if !like.case_insensitive
                && like.escape_char.is_none()
                && let Some(pattern) = literal_utf8(&like.pattern)
                && is_substring_pattern(pattern.as_bytes())
            {
                let mut usages = lineage_for_expr(&like.expr, input_lineage, schema)?;
                for usage in &mut usages {
                    usage.operations.push(Operation::SubstringSearch);
                }
                return Ok(usages);
            }
            propagate_other(expr, input_lineage, schema)
        }
        Expr::Cast(cast) => {
            let mut usages = lineage_for_expr(&cast.expr, input_lineage, schema)?;
            for usage in &mut usages {
                usage.operations.push(Operation::Other);
            }
            Ok(usages)
        }
        Expr::TryCast(cast) => {
            let mut usages = lineage_for_expr(&cast.expr, input_lineage, schema)?;
            for usage in &mut usages {
                usage.operations.push(Operation::Other);
            }
            Ok(usages)
        }
        Expr::Literal(_, _) => Ok(Vec::new()),
        Expr::ScalarSubquery(_) | Expr::Exists { .. } => Ok(Vec::new()),
        Expr::Placeholder(_) => Ok(Vec::new()),
        #[allow(deprecated)]
        Expr::Wildcard { .. } => {
            let mut usages = Vec::new();
            for column_usages in input_lineage.values() {
                usages.extend(column_usages.clone());
            }
            Ok(usages)
        }
        _ => propagate_other(expr, input_lineage, schema),
    }
}

fn propagate_other(
    expr: &Expr,
    input_lineage: &LineageMap,
    schema: &DFSchema,
) -> Result<Vec<ColumnUsage>> {
    let mut combined: Vec<ColumnUsage> = Vec::new();
    expr.apply_children(|child| {
        let mut usages = lineage_for_expr(child, input_lineage, schema)?;
        for usage in &mut usages {
            usage.operations.push(Operation::Other);
        }
        combined.extend(usages);
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(combined)
}

fn literal_utf8(expr: &Expr) -> Option<String> {
    match expr {
        Expr::Literal(value, _) => match value {
            ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => Some(v.clone()),
            ScalarValue::Utf8View(Some(v)) => Some(v.clone()),
            _ => None,
        },
        _ => None,
    }
}

fn is_substring_pattern(pattern: &[u8]) -> bool {
    if pattern.len() < 2 {
        return false;
    }
    if pattern[0] != b'%' || pattern[pattern.len() - 1] != b'%' {
        return false;
    }
    let inner = &pattern[1..pattern.len() - 1];
    if inner.is_empty() {
        return false;
    }
    !inner.iter().any(|b| *b == b'%' || *b == b'_')
}

fn is_string_type(data_type: &DataType) -> bool {
    match data_type {
        DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => true,
        DataType::Dictionary(_, value_type) => is_string_type(value_type.as_ref()),
        _ => false,
    }
}

fn is_date_part_type(data_type: &DataType) -> bool {
    matches!(data_type, DataType::Date32 | DataType::Timestamp(_, _))
}

fn literal_data_type(expr: &Expr) -> Option<DataType> {
    literal_utf8(expr).and_then(|spec| DataType::from_str(&spec).ok())
}

fn part_to_unit(expr: &Expr) -> Option<SupportedIntervalUnit> {
    let value = match expr {
        Expr::Literal(literal, _) => literal,
        _ => return None,
    };
    let text = match value {
        ScalarValue::Utf8(Some(v))
        | ScalarValue::LargeUtf8(Some(v))
        | ScalarValue::Utf8View(Some(v)) => v.as_str(),
        _ => return None,
    };
    let lowered = text.to_ascii_lowercase();
    match lowered.as_str() {
        "dow" | "dayofweek" | "day_of_week" => {
            return Some(SupportedIntervalUnit::DayOfWeek);
        }
        _ => {}
    }
    let unit = IntervalUnit::from_str(text).ok()?;
    match unit {
        IntervalUnit::Year => Some(SupportedIntervalUnit::Year),
        IntervalUnit::Month => Some(SupportedIntervalUnit::Month),
        IntervalUnit::Day => Some(SupportedIntervalUnit::Day),
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use crate::optimizers::{
        DATE_MAPPING_METADATA_KEY, LocalModeOptimizer, VARIANT_MAPPING_METADATA_KEY,
    };
    use crate::{LiquidCacheParquet, VariantGetUdf, VariantToJsonUdf};
    use liquid_cache::cache::AlwaysHydrate;
    use liquid_cache_common::IoMode;

    use super::*;
    use arrow::array::{ArrayRef, Date32Array, StringArray, TimestampMicrosecondArray};
    use arrow_schema::{Field, Schema, TimeUnit};
    use datafusion::catalog::memory::DataSourceExec;
    use datafusion::datasource::physical_plan::FileScanConfig;
    use datafusion::execution::SessionStateBuilder;
    use datafusion::logical_expr::ScalarUDF;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
    use liquid_cache::cache::squeeze_policies::TranscodeSqueezeEvict;
    use liquid_cache::cache_policies::LiquidPolicy;
    use parquet::arrow::ArrowWriter;
    use parquet::variant::{VariantArray, json_to_variant};
    use serde::Deserialize;
    use tempfile::TempDir;

    fn create_physical_optimizer() -> LocalModeOptimizer {
        LocalModeOptimizer::with_cache(Arc::new(LiquidCacheParquet::new(
            1024,
            1024 * 1024 * 1024,
            PathBuf::from("test"),
            Box::new(LiquidPolicy::new()),
            Box::new(TranscodeSqueezeEvict),
            Box::new(AlwaysHydrate::new()),
            IoMode::Uring,
        )))
    }

    fn create_session_context(optimizer: Arc<LineageOptimizer>) -> SessionContext {
        let state = SessionStateBuilder::new()
            .with_config(SessionConfig::new())
            .with_default_features()
            .with_optimizer_rule(optimizer as Arc<dyn OptimizerRule + Send + Sync>)
            .with_physical_optimizer_rule(Arc::new(create_physical_optimizer()))
            .build();
        SessionContext::new_with_state(state)
    }

    fn write_date_parquet(path: &std::path::Path) {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new("date", DataType::Date32, false),
            Field::new("date_copy", DataType::Date32, false),
        ]));

        let timestamps: ArrayRef = Arc::new(TimestampMicrosecondArray::from(vec![
            Some(1_609_459_200_000_000),
            Some(1_640_995_200_000_000),
            Some(1_672_358_400_000_000),
        ]));
        let dates: ArrayRef = Arc::new(Date32Array::from(vec![
            Some(20210101),
            Some(20220202),
            Some(20230303),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            Arc::clone(&schema),
            vec![timestamps, dates.clone(), dates],
        )
        .unwrap();

        let file = std::fs::File::create(path).unwrap();
        let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
    }

    fn write_variant_parquet(path: &std::path::Path) {
        let values = StringArray::from(vec![
            Some(r#"{"name": "Alice", "age": 30}"#),
            Some(r#"{"name": "Bob", "age": 25}"#),
            Some(r#"{"name": "Charlie"}"#),
        ]);
        let input_array: ArrayRef = Arc::new(values);
        let variant: VariantArray =
            json_to_variant(&input_array).expect("variant conversion for test data");

        let schema = Arc::new(Schema::new(vec![variant.field("data")]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            Arc::clone(&schema),
            vec![ArrayRef::from(variant)],
        )
        .expect("variant batch");

        let file = std::fs::File::create(path).expect("create variant parquet file");
        let mut writer =
            ArrowWriter::try_new(file, batch.schema(), None).expect("create variant writer");
        writer.write(&batch).expect("write variant batch");
        writer.close().expect("close variant writer");
    }

    async fn setup_single_date_table() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let table_path = temp_dir.path().join("table_a.parquet");
        write_date_parquet(&table_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_parquet(
            "table_a",
            table_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }

    async fn setup_dual_date_tables() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let table_a_path = temp_dir.path().join("table_a.parquet");
        let table_b_path = temp_dir.path().join("table_b.parquet");
        write_date_parquet(&table_a_path);
        write_date_parquet(&table_b_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_parquet(
            "table_a",
            table_a_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();
        ctx.register_parquet(
            "table_b",
            table_b_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }

    async fn setup_variant_table() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let variant_path = temp_dir.path().join("variants_test.parquet");
        write_variant_parquet(&variant_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default()));
        ctx.register_udf(ScalarUDF::new_from_impl(VariantToJsonUdf::default()));
        ctx.register_parquet(
            "variants_test",
            variant_path.to_str().unwrap(),
            ParquetReadOptions::default().skip_metadata(false),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }

    async fn setup_date_and_variant_tables() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let table_a_path = temp_dir.path().join("table_a.parquet");
        let variant_path = temp_dir.path().join("variants_test.parquet");
        write_date_parquet(&table_a_path);
        write_variant_parquet(&variant_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default()));
        ctx.register_udf(ScalarUDF::new_from_impl(VariantToJsonUdf::default()));
        ctx.register_parquet(
            "table_a",
            table_a_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();
        ctx.register_parquet(
            "variants_test",
            variant_path.to_str().unwrap(),
            ParquetReadOptions::default().skip_metadata(false),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }

    fn extract_field_metadata(
        plan: &Arc<dyn ExecutionPlan>,
        metadata_key: &str,
    ) -> HashMap<String, String> {
        let mut field_metadata_map = HashMap::new();

        plan.apply(|node| {
            let Some(data_source) = node.as_any().downcast_ref::<DataSourceExec>() else {
                return Ok(TreeNodeRecursion::Continue);
            };
            let Some(file_scan_config) = data_source
                .data_source()
                .as_any()
                .downcast_ref::<FileScanConfig>()
            else {
                return Ok(TreeNodeRecursion::Continue);
            };

            let file_schema = &file_scan_config.file_schema();
            for field in file_schema.fields() {
                if let Some(metadata_value) = field.metadata().get(metadata_key) {
                    field_metadata_map.insert(field.name().to_string(), metadata_value.clone());
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })
        .unwrap();
        field_metadata_map
    }

    #[derive(Debug, Deserialize)]
    struct VariantMetadataEntry {
        path: String,
        #[serde(rename = "type")]
        data_type: Option<String>,
    }

    fn parse_variant_metadata(value: &str) -> Vec<VariantMetadataEntry> {
        serde_json::from_str(value).unwrap_or_else(|_| {
            vec![VariantMetadataEntry {
                path: value.to_string(),
                data_type: None,
            }]
        })
    }

    fn variant_paths_from_metadata(value: &str) -> Vec<String> {
        parse_variant_metadata(value)
            .into_iter()
            .map(|entry| entry.path)
            .collect()
    }

    async fn assert_metadata(
        ctx: &SessionContext,
        sql: &str,
        expected_date: Vec<(&str, &str)>,
        expected_variant: Vec<&str>,
    ) {
        let df = ctx.sql(sql).await.unwrap();
        let (state, plan) = df.into_parts();
        let optimized = state.optimize(&plan).unwrap();
        let physical_plan = state.create_physical_plan(&optimized).await.unwrap();

        let date_metadata = extract_field_metadata(&physical_plan, DATE_MAPPING_METADATA_KEY);
        let variant_metadata = extract_field_metadata(&physical_plan, VARIANT_MAPPING_METADATA_KEY);

        let expected_date_map: HashMap<String, String> = expected_date
            .into_iter()
            .map(|(col, val)| (col.to_string(), val.to_string()))
            .collect();
        assert_eq!(
            date_metadata, expected_date_map,
            "date metadata mismatch for SQL: {}",
            sql
        );

        if expected_variant.is_empty() {
            assert!(
                !variant_metadata.contains_key("data"),
                "variant metadata should not be present for SQL: {}",
                sql
            );
        } else {
            let mut actual = variant_metadata
                .get("data")
                .map(|v| variant_paths_from_metadata(v))
                .unwrap_or_default();
            actual.sort();
            let mut expected: Vec<String> = expected_variant
                .into_iter()
                .map(|s| s.to_string())
                .collect();
            expected.sort();
            assert_eq!(
                actual, expected,
                "variant metadata mismatch for SQL: {}",
                sql
            );
        }
    }

    #[tokio::test]
    async fn extract_day_basic() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM date) AS day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_dow_basic() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT date_part('dow', date) AS dow FROM table_a",
            vec![("date", "DOW")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_lowercase() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(day FROM date) AS day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_qualified_column() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM table_a.date) FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_in_avg() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT AVG(EXTRACT(DAY FROM date)) AS avg_day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_in_expression() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT AVG(EXTRACT(DAY FROM date) + 1) AS avg_day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_in_subqueries() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT (SELECT MAX(EXTRACT(DAY FROM date)) FROM table_a) AS max_day, (SELECT MIN(EXTRACT(DAY FROM date)) FROM table_a) AS min_day",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_and_month() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM date) AS day, EXTRACT(MONTH FROM date) AS month FROM table_a",
            vec![("date", "MONTH,DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_day_and_month_subqueries() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT (SELECT MAX(EXTRACT(DAY FROM date)) FROM table_a) AS max_day, (SELECT MIN(EXTRACT(Month FROM date)) FROM table_a) AS min_day",
            vec![("date", "MONTH,DAY")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn extract_from_joined_tables() {
        let (_dir, ctx, _) = setup_dual_date_tables().await;
        let df = ctx
            .sql("SELECT EXTRACT(YEAR FROM table_a.date) AS year, EXTRACT(DAY FROM table_b.date) AS day FROM table_a INNER JOIN table_b ON table_a.event_ts = table_b.event_ts")
            .await
            .unwrap();
        let (state, plan) = df.into_parts();
        let optimized = state.optimize(&plan).unwrap();
        let physical_plan = state.create_physical_plan(&optimized).await.unwrap();
        let metadata = extract_field_metadata(&physical_plan, DATE_MAPPING_METADATA_KEY);

        assert!(
            metadata.contains_key("date"),
            "date column should have extraction metadata"
        );
        let value = metadata.get("date").unwrap();
        assert!(
            value == "YEAR" || value == "DAY",
            "expected YEAR or DAY, got {}",
            value
        );
    }

    #[tokio::test]
    async fn no_extraction_with_interval_arithmetic() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM date + INTERVAL '1 day') AS day FROM table_a",
            vec![],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn no_extraction_when_column_used_directly() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(&ctx, "SELECT date FROM table_a", vec![], vec![]).await;
    }

    #[tokio::test]
    async fn no_extraction_when_used_in_join_condition() {
        let (_dir, ctx, _) = setup_dual_date_tables().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM table_a.date) AS day FROM table_a INNER JOIN table_b ON table_a.date = table_b.date",
            vec![],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn timestamp_extraction_supported() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(YEAR FROM event_ts) AS year FROM table_a",
            vec![("event_ts", "YEAR")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn timestamp_dow_extraction() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT date_part('dow', event_ts) AS dow FROM table_a",
            vec![("event_ts", "DOW")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn metadata_only_on_extracted_fields() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(YEAR FROM date) AS year, event_ts FROM table_a",
            vec![("date", "YEAR")],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_single_path() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_to_json(variant_get(data, 'name')) FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_duplicate_paths() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name'), variant_get(data, 'name') AS name2 FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_with_to_json() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_to_json(variant_get(data, 'age')), variant_to_json(variant_get(data, 'age')) AS age2 FROM variants_test",
            vec![],
            vec!["age"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_qualified_column() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(variants_test.data, 'name') as name1, variant_get(variants_test.data, 'name') as name2 FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_in_aggregates() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT COUNT(variant_get(data, 'age')), MAX(variant_get(data, 'age')) FROM variants_test",
            vec![],
            vec!["age"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_with_where_clause() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name') FROM variants_test WHERE variant_get(data, 'name') IS NOT NULL",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_in_subqueries() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT (SELECT MAX(variant_get(data, 'name')) FROM variants_test) AS max_name, (SELECT MIN(variant_get(data, 'name')) FROM variants_test) AS min_name",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_multiple_paths() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name'), variant_get(data, 'date') FROM variants_test",
            vec![],
            vec!["date", "name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_nested() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(variant_get(data, 'name'), 'age') FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_conflicting_types_no_metadata() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name', 'Utf8'), variant_get(data, 'name') FROM variants_test",
            vec![],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_type_hint_propagated() {
        let (_dir, ctx, _) = setup_variant_table().await;

        let df = ctx
            .sql("SELECT variant_get(data, 'name', 'Utf8') FROM variants_test")
            .await
            .unwrap();
        let (state, plan) = df.into_parts();
        let optimized = state.optimize(&plan).unwrap();
        let physical_plan = state.create_physical_plan(&optimized).await.unwrap();

        let metadata = extract_field_metadata(&physical_plan, VARIANT_MAPPING_METADATA_KEY);

        let entries = metadata
            .get("data")
            .map(|value| parse_variant_metadata(value))
            .unwrap_or_default();
        let entry = entries
            .iter()
            .find(|entry| entry.path == "name")
            .expect("variant metadata entry for name");
        assert_eq!(entry.data_type.as_deref(), Some("Utf8"));
    }

    #[tokio::test]
    async fn variant_get_conflicting_types_in_filter() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_to_json(variant_get(data, 'name')) FROM variants_test WHERE variant_get(data, 'name', 'Utf8') = 'Bob'",
            vec![],
            vec![],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_multiple_paths_with_types() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'did', 'Utf8') as user_id,
                MAX(TO_TIMESTAMP_MICROS(variant_get(data, 'time_us', 'Int64'))) - MIN(TO_TIMESTAMP_MICROS(variant_get(data, 'time_us', 'Int64')))
             FROM variants_test GROUP BY user_id",
            vec![],
            vec!["did", "time_us"],
        )
        .await;
    }

    #[tokio::test]
    async fn mixed_date_and_variant_basic() {
        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM table_a.date) AS day, variant_get(variants_test.data, 'name') AS name FROM table_a CROSS JOIN variants_test",
            vec![("date", "DAY")],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn mixed_multiple_date_and_variant() {
        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(YEAR FROM table_a.date) AS year, EXTRACT(MONTH FROM table_a.date_copy) AS month, variant_get(variants_test.data, 'name') AS name, variant_get(variants_test.data, 'age') AS age FROM table_a CROSS JOIN variants_test",
            vec![("date", "YEAR"), ("date_copy", "MONTH")],
            vec!["age", "name"],
        )
        .await;
    }

    #[tokio::test]
    async fn mixed_date_in_where_clause() {
        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(variants_test.data, 'name') AS name FROM variants_test CROSS JOIN table_a WHERE EXTRACT(DAY FROM table_a.date) > 1",
            vec![("date", "DAY")],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn mixed_variant_in_where_clause() {
        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM table_a.date) AS day FROM table_a CROSS JOIN variants_test WHERE variant_get(variants_test.data, 'name') IS NOT NULL",
            vec![("date", "DAY")],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn mixed_date_with_variant_subquery() {
        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(YEAR FROM table_a.date) AS year, (SELECT variant_get(variants_test.data, 'name') FROM variants_test LIMIT 1) AS name FROM table_a",
            vec![("date", "YEAR")],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn mixed_in_aggregates() {
        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
        assert_metadata(
            &ctx,
            "SELECT AVG(EXTRACT(DAY FROM table_a.date)) AS avg_day, COUNT(variant_get(variants_test.data, 'name')) AS name_count FROM table_a CROSS JOIN variants_test",
            vec![("date", "DAY")],
            vec!["name"],
        )
        .await;
    }
}