Skip to main content

liquid_cache_datafusion/optimizers/
lineage_opt.rs

//! This module has a logical optimizer that detects columns that are only used via compatible `EXTRACT` projections.
//! It then attaches the metadata to the expression adapter factory, which is then passed to the physical plan.
//! The physical optimizer will move the metadata to the fields of the schema.

5use std::cmp::Ordering;
6use std::collections::hash_map::Entry;
7use std::collections::{HashMap, HashSet};
8use std::str::FromStr;
9use std::sync::{Arc, Mutex, OnceLock};
10
11use arrow::compute::kernels::cast_utils::IntervalUnit;
12use arrow_schema::{DataType, Schema, SchemaRef};
13use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
14use datafusion::common::{
15    Column, Constraints, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, TableReference,
16};
17use datafusion::datasource::listing::{ListingTable, ListingTableConfig};
18use datafusion::datasource::{TableProvider, provider_as_source, source_as_provider};
19use datafusion::logical_expr::logical_plan::{
20    Aggregate, Distinct, DistinctOn, Filter, Join, Limit, LogicalPlan, Partitioning, Projection,
21    Repartition, Sort, SubqueryAlias, TableScan, Union, Window,
22};
23use datafusion::logical_expr::{Expr, TableSource};
24use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
25use datafusion::physical_expr_adapter::{
26    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
27};
28
/// Supported components for `EXTRACT` clauses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SupportedIntervalUnit {
    Year,
    Month,
    Day,
    DayOfWeek,
}

impl SupportedIntervalUnit {
    /// String token stored in schema metadata for this component.
    pub(crate) fn metadata_value(self) -> &'static str {
        match self {
            Self::Year => "YEAR",
            Self::Month => "MONTH",
            Self::Day => "DAY",
            Self::DayOfWeek => "DOW",
        }
    }
}
48
/// Metadata describing a Date32/Timestamp column that participates in an `EXTRACT`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct DateExtraction {
    /// Base table column the `EXTRACT` expressions resolve to.
    pub(crate) column: Column,
    /// All `EXTRACT` components observed for this column across the plan.
    pub(crate) components: HashSet<SupportedIntervalUnit>,
}
55
/// Metadata describing a Variant column that participates in a `variant_get`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VariantExtraction {
    /// Base table column the `variant_get` calls resolve to.
    pub(crate) column: Column,
    /// Distinct (path, requested type) pairs extracted from this column.
    pub(crate) fields: Vec<VariantField>,
}
62
impl PartialOrd for VariantExtraction {
    // Delegate to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
68
69impl Ord for VariantExtraction {
70    fn cmp(&self, other: &Self) -> Ordering {
71        self.column
72            .flat_name()
73            .cmp(&other.column.flat_name())
74            .then_with(|| self.fields.cmp(&other.fields))
75    }
76}
77
/// One extracted path from a Variant column, optionally with a requested type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VariantField {
    /// Path passed to `variant_get`.
    pub(crate) path: String,
    /// Requested output type, when the call specified one.
    pub(crate) data_type: Option<DataType>,
}
83
impl PartialOrd for VariantField {
    // Delegate to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
89
90impl Ord for VariantField {
91    fn cmp(&self, other: &Self) -> Ordering {
92        self.path.cmp(&other.path).then_with(|| {
93            let self_ty = self
94                .data_type
95                .as_ref()
96                .map(|dt| dt.to_string())
97                .unwrap_or_default();
98            let other_ty = other
99                .data_type
100                .as_ref()
101                .map(|dt| dt.to_string())
102                .unwrap_or_default();
103            self_ty.cmp(&other_ty)
104        })
105    }
106}
107
/// Annotation that should be attached to a column in the file schema.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ColumnAnnotation {
    /// Column is only consumed through `EXTRACT` of these components.
    DatePart(HashSet<SupportedIntervalUnit>),
    /// Column is only consumed through `variant_get` of these paths.
    VariantPaths(Vec<VariantField>),
    /// Column is only consumed through substring searches.
    SubstringSearch,
}
115
116pub(crate) fn serialize_date_part(units: &HashSet<SupportedIntervalUnit>) -> String {
117    let mut sorted_units: Vec<&SupportedIntervalUnit> = units.iter().collect();
118    // Sort by a consistent order: Year, Month, Day, DayOfWeek
119    sorted_units.sort_by_key(|unit| match unit {
120        SupportedIntervalUnit::Year => 0,
121        SupportedIntervalUnit::Month => 1,
122        SupportedIntervalUnit::Day => 2,
123        SupportedIntervalUnit::DayOfWeek => 3,
124    });
125    sorted_units
126        .iter()
127        .map(|unit| unit.metadata_value())
128        .collect::<Vec<_>>()
129        .join(",")
130}
131
/// Logical optimizer that analyses the logical plan to detect columns that
/// are only used via compatible `EXTRACT` or `variant_get` projections.
#[derive(Debug, Default)]
pub struct LineageOptimizer;

impl LineageOptimizer {
    /// Create a new optimizer.
    pub fn new() -> Self {
        Self::default()
    }
}
143
144impl OptimizerRule for LineageOptimizer {
145    fn name(&self) -> &str {
146        "LineageOptimizer"
147    }
148
149    fn apply_order(&self) -> Option<ApplyOrder> {
150        // so that it won't recursively apply the rule to every node.
151        None
152    }
153
154    fn rewrite(
155        &self,
156        plan: LogicalPlan,
157        _config: &dyn OptimizerConfig,
158    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
159        let mut analyzer = LineageAnalyzer::default();
160        let _ = analyzer.analyze_plan(&plan)?;
161        let table_usage = analyzer.finish();
162        let mut date_findings = table_usage.find_date32_extractions();
163        date_findings.sort_by(|a, b| a.column.flat_name().cmp(&b.column.flat_name()));
164
165        let mut variant_findings = table_usage.find_variant_gets();
166        variant_findings.sort();
167
168        let mut substring_findings = table_usage.find_substring_searches();
169        substring_findings.sort_by_key(|a| a.flat_name());
170
171        let annotations =
172            build_annotation_map(&date_findings, &variant_findings, &substring_findings);
173        annotate_plan_with_extractions(plan, &annotations)
174    }
175}
176
/// Per-plan-node lineage: output column key → usages of base columns feeding it.
type LineageMap = HashMap<ColumnKey, Vec<ColumnUsage>>;
178
179#[derive(Clone, Debug, PartialEq, Eq, Hash)]
180struct ColumnKey {
181    relation: Option<TableReference>,
182    name: String,
183}
184
185impl ColumnKey {
186    fn new(relation: Option<TableReference>, name: impl Into<String>) -> Self {
187        Self {
188            relation,
189            name: name.into(),
190        }
191    }
192
193    fn from_column(column: &Column) -> Self {
194        Self {
195            relation: column.relation.clone(),
196            name: column.name.clone(),
197        }
198    }
199
200    fn to_column(&self) -> Column {
201        Column::new(self.relation.clone(), self.name.clone())
202    }
203}
204
/// One usage path of a base column: the base identity, its data type, and the
/// chain of operations applied on top of the plain column reference.
#[derive(Debug, Clone)]
struct ColumnUsage {
    /// Identity of the base table column.
    base: ColumnKey,
    /// Data type of the base column.
    data_type: DataType,
    /// Operations applied to the base column in this usage path.
    operations: Vec<Operation>,
}

impl ColumnUsage {
    /// A plain reference to `column` with no operations applied yet.
    fn new_base(column: &Column, data_type: DataType) -> Self {
        Self {
            base: ColumnKey::from_column(column),
            data_type,
            operations: Vec::new(),
        }
    }
}
221
/// One operation observed on top of a base column reference.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Operation {
    /// `EXTRACT(<unit> FROM col)` with a supported unit.
    Extract(SupportedIntervalUnit),
    /// `variant_get(col, path)` optionally cast to a requested type.
    VariantGet {
        path: String,
        data_type: Option<DataType>,
    },
    /// Substring search (e.g. LIKE/contains-style predicate) on the column.
    SubstringSearch,
    /// Any other operation — disqualifies the specialized annotations.
    Other,
}
232
/// Aggregated usage observations for a single base column.
#[derive(Debug)]
struct UsageStats {
    /// Data type of the base column.
    data_type: DataType,
    /// One entry per recorded usage path; each inner vec is that path's
    /// operation chain (empty = plain passthrough of the column).
    usages: Vec<Vec<Operation>>,
}

impl UsageStats {
    /// Start an empty stats record for a column of the given type.
    fn new(data_type: DataType) -> Self {
        Self {
            data_type,
            usages: Vec::new(),
        }
    }

    /// Append one observed usage path's operation chain.
    fn apply(&mut self, usage: &ColumnUsage) {
        self.usages.push(usage.operations.clone());
    }
}
251
/// Walks a logical plan, recording how every base table column is used.
#[derive(Default)]
struct LineageAnalyzer {
    /// Per-base-column aggregated usage observations.
    stats: HashMap<ColumnKey, UsageStats>,
}
256
257impl LineageAnalyzer {
    /// Analyze one plan node and return the lineage of its output columns.
    ///
    /// Node kinds without a dedicated handler fall through to the default
    /// arm, which recurses into all children and merges their lineage maps.
    fn analyze_plan(&mut self, plan: &LogicalPlan) -> Result<LineageMap> {
        match plan {
            LogicalPlan::TableScan(scan) => self.analyze_table_scan(scan),
            LogicalPlan::Projection(projection) => self.analyze_projection(projection),
            LogicalPlan::Filter(filter) => self.analyze_filter(filter),
            LogicalPlan::Aggregate(aggregate) => self.analyze_aggregate(aggregate),
            LogicalPlan::Sort(sort) => self.analyze_sort(sort),
            LogicalPlan::Join(join) => self.analyze_join(join),
            LogicalPlan::SubqueryAlias(alias) => self.analyze_subquery_alias(alias),
            LogicalPlan::Window(window) => self.analyze_window(window),
            LogicalPlan::Limit(limit) => self.analyze_limit(limit),
            LogicalPlan::Repartition(repartition) => self.analyze_repartition(repartition),
            LogicalPlan::Union(union) => self.analyze_union(union),
            LogicalPlan::Distinct(distinct) => self.analyze_distinct(distinct),
            other => {
                // Unknown node: merge the lineage of all children so nothing
                // below this node is lost, without mapping output columns.
                let mut merged = LineageMap::new();
                for input in other.inputs() {
                    let child = self.analyze_plan(input)?;
                    merged = merge_lineage_maps(merged, child);
                }
                Ok(merged)
            }
        }
    }
282
    /// Seed the lineage map at a leaf: every projected column of the scan is
    /// its own base usage with an empty operation chain.
    fn analyze_table_scan(&mut self, scan: &TableScan) -> Result<LineageMap> {
        let schema = scan.projected_schema.as_ref();
        let mut map = LineageMap::new();
        for (index, column) in schema.columns().iter().enumerate() {
            let field = schema.field(index);
            let usage = ColumnUsage::new_base(column, field.data_type().clone());
            map.insert(usage.base.clone(), vec![usage]);
        }

        // Pushed-down filters also consume columns; record those usages.
        for filter in &scan.filters {
            let usages = lineage_for_expr(filter, &map, schema)?;
            self.record(&usages);
        }

        Ok(map)
    }
299
    /// Record each projected expression's usages and map its lineage to the
    /// corresponding output column (expressions and output columns are
    /// positionally aligned).
    fn analyze_projection(&mut self, projection: &Projection) -> Result<LineageMap> {
        let input_map = self.analyze_plan(projection.input.as_ref())?;
        let input_schema = projection.input.schema();
        let mut output = LineageMap::new();
        for (expr, column) in projection.expr.iter().zip(projection.schema.columns()) {
            let usages = lineage_for_expr(expr, &input_map, input_schema.as_ref())?;
            self.record(&usages);
            output.insert(ColumnKey::from_column(&column), usages);
        }
        Ok(output)
    }
311
312    fn analyze_filter(&mut self, filter: &Filter) -> Result<LineageMap> {
313        let input_map = self.analyze_plan(filter.input.as_ref())?;
314        let schema = filter.input.schema();
315        let usages = lineage_for_expr(&filter.predicate, &input_map, schema.as_ref())?;
316        self.record(&usages);
317        Ok(input_map)
318    }
319
320    fn analyze_sort(&mut self, sort: &Sort) -> Result<LineageMap> {
321        let input_map = self.analyze_plan(sort.input.as_ref())?;
322        let schema = sort.input.schema();
323        for sort_expr in &sort.expr {
324            let usages = lineage_for_expr(&sort_expr.expr, &input_map, schema.as_ref())?;
325            self.record(&usages);
326        }
327        Ok(input_map)
328    }
329
    /// Map aggregate output columns to the lineage of the expressions that
    /// produce them.
    ///
    /// NOTE(review): assumes the output schema lists group-by columns first,
    /// followed by aggregate columns, matching `group_expr` then `aggr_expr`
    /// order — confirm against the DataFusion version in use.
    fn analyze_aggregate(&mut self, aggregate: &Aggregate) -> Result<LineageMap> {
        let input_map = self.analyze_plan(aggregate.input.as_ref())?;
        let schema = aggregate.input.schema();
        let mut output = LineageMap::new();
        let mut expr_iter = aggregate
            .group_expr
            .iter()
            .chain(aggregate.aggr_expr.iter());

        for column in aggregate.schema.columns() {
            if let Some(expr) = expr_iter.next() {
                let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
                self.record(&usages);
                output.insert(ColumnKey::from_column(&column), usages);
            } else {
                // More output columns than expressions: map them to nothing
                // rather than failing the analysis.
                output.insert(ColumnKey::from_column(&column), Vec::new());
            }
        }

        Ok(output)
    }
351
    /// Record usages from join keys and the join filter, then map the join's
    /// output columns back to the lineage of the corresponding input columns.
    ///
    /// NOTE(review): assumes the join output schema is the left schema's
    /// columns followed by the right schema's columns, positionally — confirm
    /// for join types that drop columns (e.g. semi/anti joins).
    fn analyze_join(&mut self, join: &Join) -> Result<LineageMap> {
        let left_map = self.analyze_plan(join.left.as_ref())?;
        let right_map = self.analyze_plan(join.right.as_ref())?;
        let left_schema = join.left.schema();
        let right_schema = join.right.schema();

        // Equi-join keys consume columns from their respective sides.
        for (left_expr, right_expr) in &join.on {
            let left_usages = lineage_for_expr(left_expr, &left_map, left_schema.as_ref())?;
            self.record(&left_usages);
            let right_usages = lineage_for_expr(right_expr, &right_map, right_schema.as_ref())?;
            self.record(&right_usages);
        }

        // The filter may reference both sides, so resolve it against the
        // merged lineage and the combined join schema.
        if let Some(filter) = &join.filter {
            let mut combined = left_map.clone();
            merge_lineage_map_inplace(&mut combined, &right_map);
            let usages = lineage_for_expr(filter, &combined, join.schema.as_ref())?;
            self.record(&usages);
        }

        let left_columns = left_schema.columns();
        let right_columns = right_schema.columns();
        let mut output_columns = join.schema.columns().into_iter();
        let mut output = LineageMap::new();

        for column in left_columns {
            if let Some(output_column) = output_columns.next() {
                let key = ColumnKey::from_column(&output_column);
                let usages = left_map
                    .get(&ColumnKey::from_column(&column))
                    .cloned()
                    .unwrap_or_default();
                output.insert(key, usages);
            }
        }

        for column in right_columns {
            if let Some(output_column) = output_columns.next() {
                let key = ColumnKey::from_column(&output_column);
                let usages = right_map
                    .get(&ColumnKey::from_column(&column))
                    .cloned()
                    .unwrap_or_default();
                output.insert(key, usages);
            }
        }

        Ok(output)
    }
401
    /// An alias only renames/requalifies columns: copy each input column's
    /// lineage through to the positionally matching aliased output column.
    fn analyze_subquery_alias(&mut self, alias: &SubqueryAlias) -> Result<LineageMap> {
        let input_map = self.analyze_plan(alias.input.as_ref())?;
        let input_columns = alias.input.schema().columns();
        let mut output = LineageMap::new();
        for (input_column, output_column) in
            input_columns.iter().zip(alias.schema.columns().into_iter())
        {
            let key = ColumnKey::from_column(&output_column);
            let usages = input_map
                .get(&ColumnKey::from_column(input_column))
                .cloned()
                .unwrap_or_default();
            output.insert(key, usages);
        }
        Ok(output)
    }
418
    /// A window node passes all input columns through and appends one output
    /// column per window expression after them.
    fn analyze_window(&mut self, window: &Window) -> Result<LineageMap> {
        let input_map = self.analyze_plan(window.input.as_ref())?;
        let input_schema = window.input.schema();

        let input_cols = input_schema.columns();
        let output_cols = window.schema.columns();
        let mut output = LineageMap::new();

        // Pass-through columns keep their input lineage.
        for (input_column, output_column) in input_cols.iter().zip(output_cols.iter()) {
            let key = ColumnKey::from_column(output_column);
            let usages = input_map
                .get(&ColumnKey::from_column(input_column))
                .cloned()
                .unwrap_or_default();
            output.insert(key, usages);
        }

        // Window expressions produce the trailing output columns.
        for (expr, output_column) in window
            .window_expr
            .iter()
            .zip(output_cols.into_iter().skip(input_cols.len()))
        {
            let usages = lineage_for_expr(expr, &input_map, input_schema.as_ref())?;
            self.record(&usages);
            output.insert(ColumnKey::from_column(&output_column), usages);
        }

        Ok(output)
    }
448
449    fn analyze_limit(&mut self, limit: &Limit) -> Result<LineageMap> {
450        let map = self.analyze_plan(limit.input.as_ref())?;
451        let schema = limit.input.schema();
452        if let Some(skip) = &limit.skip {
453            let usages = lineage_for_expr(skip, &map, schema.as_ref())?;
454            self.record(&usages);
455        }
456        if let Some(fetch) = &limit.fetch {
457            let usages = lineage_for_expr(fetch, &map, schema.as_ref())?;
458            self.record(&usages);
459        }
460        Ok(map)
461    }
462
    /// Record usages from partitioning expressions (hash / distribute-by);
    /// round-robin partitioning references no columns. The input lineage map
    /// passes through unchanged.
    fn analyze_repartition(&mut self, repartition: &Repartition) -> Result<LineageMap> {
        let map = self.analyze_plan(repartition.input.as_ref())?;
        let schema = repartition.input.schema();
        if let Partitioning::Hash(exprs, _) | Partitioning::DistributeBy(exprs) =
            &repartition.partitioning_scheme
        {
            for expr in exprs {
                let usages = lineage_for_expr(expr, &map, schema.as_ref())?;
                self.record(&usages);
            }
        }
        Ok(map)
    }
476
    /// Combine the lineage of all union branches per output column.
    ///
    /// Branch columns are matched by bare name (qualifiers intentionally
    /// ignored, since each branch may carry a different table qualifier).
    fn analyze_union(&mut self, union: &Union) -> Result<LineageMap> {
        let mut input_maps: Vec<LineageMap> = Vec::with_capacity(union.inputs.len());
        for input in &union.inputs {
            input_maps.push(self.analyze_plan(input.as_ref())?);
        }

        let mut output = LineageMap::new();
        for output_column in union.schema.columns() {
            let key = ColumnKey::from_column(&output_column);
            let mut combined: Vec<ColumnUsage> = Vec::new();
            for map in &input_maps {
                for (candidate_key, usages) in map {
                    if candidate_key.name == key.name {
                        combined.extend(usages.clone());
                    }
                }
            }
            output.insert(key, combined);
        }
        Ok(output)
    }
498
    /// `DISTINCT` is a passthrough; `DISTINCT ON` has its own handler because
    /// it evaluates expressions.
    fn analyze_distinct(&mut self, distinct: &Distinct) -> Result<LineageMap> {
        match distinct {
            Distinct::All(plan) => self.analyze_plan(plan.as_ref()),
            Distinct::On(distinct_on) => self.analyze_distinct_on(distinct_on),
        }
    }
505
506    fn analyze_distinct_on(&mut self, distinct_on: &DistinctOn) -> Result<LineageMap> {
507        let input_map = self.analyze_plan(distinct_on.input.as_ref())?;
508        let schema = distinct_on.input.schema();
509
510        for expr in &distinct_on.on_expr {
511            let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
512            self.record(&usages);
513        }
514        for expr in &distinct_on.select_expr {
515            let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
516            self.record(&usages);
517        }
518        if let Some(sort_exprs) = &distinct_on.sort_expr {
519            for sort_expr in sort_exprs {
520                let usages = lineage_for_expr(&sort_expr.expr, &input_map, schema.as_ref())?;
521                self.record(&usages);
522            }
523        }
524
525        let mut output = LineageMap::new();
526        for (expr, column) in distinct_on
527            .select_expr
528            .iter()
529            .zip(distinct_on.schema.columns().into_iter())
530        {
531            let usages = lineage_for_expr(expr, &input_map, schema.as_ref())?;
532            output.insert(ColumnKey::from_column(&column), usages);
533        }
534        Ok(output)
535    }
536
    /// Fold the usages observed for one expression into the per-base-column
    /// stats, creating a stats entry on first sight of a column.
    fn record(&mut self, usages: &[ColumnUsage]) {
        for usage in usages {
            let entry = self
                .stats
                .entry(usage.base.clone())
                .or_insert_with(|| UsageStats::new(usage.data_type.clone()));
            entry.apply(usage);
        }
    }
546
    /// Consume the analyzer, yielding the aggregated per-column usage stats.
    fn finish(self) -> TableColumnUsage {
        TableColumnUsage { usage: self.stats }
    }
550}
551
/// Final per-base-column usage statistics produced by the lineage analysis.
struct TableColumnUsage {
    usage: HashMap<ColumnKey, UsageStats>,
}
555
556impl TableColumnUsage {
    /// Find columns of a date-part-compatible type (see `is_date_part_type`)
    /// where every usage path begins with at least one `EXTRACT`, and collect
    /// the union of extracted components.
    ///
    /// NOTE(review): a path that starts with extracts but continues with other
    /// operations still validates the column (only the leading extracts are
    /// inspected) — confirm trailing operations are intended to be ignored.
    fn find_date32_extractions(&self) -> Vec<DateExtraction> {
        let mut extractions = Vec::new();
        for (key, stats) in self.usage.iter() {
            if is_date_part_type(&stats.data_type) {
                // Collect all extract units from paths where the first n operations are all extracts
                let mut all_units = HashSet::new();
                let mut all_paths_valid = true;

                for usage in &stats.usages {
                    // Collect all Extract units from the leading sequence of extracts
                    let mut path_units = HashSet::new();
                    for op in usage {
                        match op {
                            Operation::Extract(unit) => {
                                path_units.insert(unit);
                            }
                            _ => {
                                // Stop at first non-extract operation
                                break;
                            }
                        }
                    }

                    if path_units.is_empty() {
                        // This path doesn't start with Extract, so skip this column
                        all_paths_valid = false;
                        break;
                    }

                    // Union the units from this path into the overall set
                    // (Copy specialization of `Extend` turns &unit into unit).
                    all_units.extend(path_units);
                }

                if all_paths_valid && !all_units.is_empty() {
                    extractions.push(DateExtraction {
                        column: key.to_column(),
                        components: all_units,
                    });
                }
            }
        }
        extractions
    }
600
    /// Find columns whose every non-trivial usage path begins with a
    /// `variant_get`, collecting the distinct extracted paths.
    ///
    /// A column is disqualified when any path starts with a different
    /// operation, or when the same variant path is requested with
    /// conflicting (or mixed present/absent) output types.
    fn find_variant_gets(&self) -> Vec<VariantExtraction> {
        let mut gets = Vec::new();
        for (key, stats) in self.usage.iter() {
            if stats.usages.is_empty() {
                continue;
            }

            let mut field_map: HashMap<String, VariantField> = HashMap::new();
            let mut valid = true;
            let mut saw_variant_get = false;
            for usage in &stats.usages {
                match usage.first() {
                    Some(Operation::VariantGet { path, data_type }) => {
                        saw_variant_get = true;
                        match field_map.entry(path.clone()) {
                            Entry::Vacant(entry) => {
                                entry.insert(VariantField {
                                    path: path.clone(),
                                    data_type: data_type.clone(),
                                });
                            }
                            Entry::Occupied(entry) => {
                                // Same path seen again: its requested type
                                // must agree with the first sighting.
                                let current = entry.into_mut();
                                let conflict = match (&current.data_type, data_type) {
                                    (Some(existing), Some(new_ty)) => existing != new_ty,
                                    (Some(_), None) | (None, Some(_)) => true,
                                    (None, None) => false,
                                };
                                if conflict {
                                    valid = false;
                                    break;
                                }
                            }
                        }
                    }
                    // A passthrough of the base column (no operations) should not invalidate
                    // the variant metadata, but also does not contribute a path.
                    None => continue,
                    _ => {
                        valid = false;
                        break;
                    }
                }
            }

            if valid && saw_variant_get && !field_map.is_empty() {
                // Sort for a deterministic field order in the annotation.
                let mut fields: Vec<VariantField> = field_map.into_values().collect();
                fields.sort();
                gets.push(VariantExtraction {
                    column: key.to_column(),
                    fields,
                });
            }
        }
        gets
    }
657
658    fn find_substring_searches(&self) -> Vec<Column> {
659        let mut columns = Vec::new();
660        for (key, stats) in self.usage.iter() {
661            if !is_string_type(&stats.data_type) {
662                continue;
663            }
664
665            if stats.usages.is_empty() {
666                continue;
667            }
668
669            let mut saw_substring = false;
670            let mut valid = true;
671            for usage in &stats.usages {
672                let has_substring = usage
673                    .iter()
674                    .any(|op| matches!(op, Operation::SubstringSearch));
675                if has_substring {
676                    saw_substring = true;
677                    continue;
678                }
679                if !usage.is_empty() {
680                    valid = false;
681                    break;
682                }
683            }
684
685            if valid && saw_substring {
686                columns.push(key.to_column());
687            }
688        }
689        columns
690    }
691}
692
693fn build_annotation_map(
694    date_findings: &[DateExtraction],
695    variant_findings: &[VariantExtraction],
696    substring_findings: &[Column],
697) -> HashMap<ColumnKey, ColumnAnnotation> {
698    let mut annotations: HashMap<ColumnKey, ColumnAnnotation> = HashMap::new();
699    for extraction in date_findings {
700        annotations.insert(
701            ColumnKey::from_column(&extraction.column),
702            ColumnAnnotation::DatePart(extraction.components.clone()),
703        );
704    }
705    for extraction in variant_findings {
706        annotations.insert(
707            ColumnKey::from_column(&extraction.column),
708            ColumnAnnotation::VariantPaths(extraction.fields.clone()),
709        );
710    }
711    for column in substring_findings {
712        annotations.insert(
713            ColumnKey::from_column(column),
714            ColumnAnnotation::SubstringSearch,
715        );
716    }
717    annotations
718}
719
720fn annotate_plan_with_extractions(
721    plan: LogicalPlan,
722    annotations: &HashMap<ColumnKey, ColumnAnnotation>,
723) -> Result<Transformed<LogicalPlan>, DataFusionError> {
724    if annotations.is_empty() {
725        return Ok(Transformed::no(plan));
726    }
727
728    plan.transform_up(|logical_plan| match logical_plan {
729        LogicalPlan::TableScan(mut scan) => {
730            let table_annotations = annotations_for_table_scan(&scan, annotations);
731            let mut changed = false;
732
733            if let Some(source) = annotate_listing_table_source(&scan.source, &table_annotations)? {
734                scan.source = source;
735                changed = true;
736            }
737
738            if changed {
739                Ok(Transformed::yes(LogicalPlan::TableScan(scan)))
740            } else {
741                Ok(Transformed::no(LogicalPlan::TableScan(scan)))
742            }
743        }
744        other => Ok(Transformed::no(other)),
745    })
746}
747
748fn annotations_for_table_scan(
749    scan: &TableScan,
750    annotations: &HashMap<ColumnKey, ColumnAnnotation>,
751) -> HashMap<String, ColumnAnnotation> {
752    let mut table_annotations = HashMap::new();
753
754    for (qualifier_opt, field_ref) in scan.projected_schema.iter() {
755        let qualifier_owned = qualifier_opt.cloned();
756        let name = field_ref.name().clone();
757        if let Some(unit) = annotations
758            .get(&ColumnKey::new(qualifier_owned.clone(), name.clone()))
759            .cloned()
760            .or_else(|| {
761                annotations
762                    .get(&ColumnKey::new(None, name.clone()))
763                    .cloned()
764            })
765        {
766            table_annotations.insert(name, unit);
767        }
768    }
769
770    table_annotations
771}
772
/// Rebuild a `ListingTable` source so its expression adapter factory carries
/// the discovered column annotations.
///
/// Returns `Ok(None)` when there is nothing to attach or the source is not a
/// `ListingTable`; the caller then leaves the scan untouched.
fn annotate_listing_table_source(
    source: &Arc<dyn TableSource>,
    annotations: &HashMap<String, ColumnAnnotation>,
) -> Result<Option<Arc<dyn TableSource>>, DataFusionError> {
    if annotations.is_empty() {
        return Ok(None);
    }

    // Only provider-backed sources can be rewritten.
    let provider = match source_as_provider(source) {
        Ok(provider) => provider,
        Err(_) => return Ok(None),
    };

    let Some(listing) = provider.as_any().downcast_ref::<ListingTable>() else {
        return Ok(None);
    };

    // The annotations travel two ways: inside the factory itself, and via the
    // address-keyed side-channel registry read by `metadata_from_factory`.
    let metadata_copy = annotations.clone();
    let new_factory: Arc<dyn PhysicalExprAdapterFactory> = Arc::new(
        LineageExtractPhysicalExprAdapterFactory::new(annotations.clone()),
    );
    register_factory_metadata(&new_factory, metadata_copy);
    // Recreate the listing table with the same paths/options/schema but the
    // new adapter factory, then carry over constraints and definition.
    let mut new_listing = ListingTable::try_new(
        ListingTableConfig::new_with_multi_paths(listing.table_paths().clone())
            .with_listing_options(listing.options().clone())
            .with_schema(listing_file_schema(listing))
            .with_expr_adapter_factory(new_factory),
    )?;
    new_listing = new_listing.with_constraints(listing_constraints(listing));
    new_listing = new_listing.with_definition(
        listing
            .get_table_definition()
            .map(std::string::ToString::to_string),
    );

    let new_provider: Arc<dyn TableProvider> = Arc::new(new_listing);
    Ok(Some(provider_as_source(new_provider)))
}
811
/// Adapter factory that wraps the default factory while carrying the column
/// annotations discovered by the lineage analysis.
#[derive(Debug)]
struct LineageExtractPhysicalExprAdapterFactory {
    /// Underlying factory that performs the actual adapter creation.
    base: Arc<dyn PhysicalExprAdapterFactory>,
    // Held for ownership/Debug only; consumers read the annotations through
    // the address-keyed registry (see `metadata_from_factory`).
    _annotations: HashMap<String, ColumnAnnotation>,
}

impl LineageExtractPhysicalExprAdapterFactory {
    /// Wrap the default adapter factory together with `annotations`.
    fn new(annotations: HashMap<String, ColumnAnnotation>) -> Self {
        Self {
            base: Arc::new(DefaultPhysicalExprAdapterFactory),
            _annotations: annotations,
        }
    }
}
826
impl PhysicalExprAdapterFactory for LineageExtractPhysicalExprAdapterFactory {
    /// Delegate adapter creation to the wrapped default factory; this type
    /// only exists to carry annotations alongside the scan.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        self.base.create(logical_file_schema, physical_file_schema)
    }
}
836
837fn listing_file_schema(listing: &ListingTable) -> SchemaRef {
838    let table_schema = listing.schema();
839    let file_field_count = table_schema
840        .fields()
841        .len()
842        .saturating_sub(listing.options().table_partition_cols.len());
843    let fields = table_schema
844        .fields()
845        .iter()
846        .take(file_field_count)
847        .cloned()
848        .collect::<Vec<_>>();
849    Arc::new(Schema::new(fields).with_metadata(table_schema.metadata().clone()))
850}
851
852fn listing_constraints(listing: &ListingTable) -> Constraints {
853    listing.constraints().cloned().unwrap_or_default()
854}
855
856fn factory_registry() -> &'static Mutex<HashMap<usize, HashMap<String, ColumnAnnotation>>> {
857    static REGISTRY: OnceLock<Mutex<HashMap<usize, HashMap<String, ColumnAnnotation>>>> =
858        OnceLock::new();
859    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
860}
861
862fn register_factory_metadata(
863    factory: &Arc<dyn PhysicalExprAdapterFactory>,
864    metadata: HashMap<String, ColumnAnnotation>,
865) {
866    let key = Arc::as_ptr(factory) as *const () as usize;
867    factory_registry().lock().unwrap().insert(key, metadata);
868}
869
870pub(crate) fn metadata_from_factory(
871    factory: &Arc<dyn PhysicalExprAdapterFactory>,
872    column: &str,
873) -> Option<ColumnAnnotation> {
874    let key = Arc::as_ptr(factory) as *const () as usize;
875    factory_registry()
876        .lock()
877        .unwrap()
878        .get(&key)
879        .and_then(|map| map.get(column).cloned())
880}
881
882fn merge_lineage_maps(mut base: LineageMap, other: LineageMap) -> LineageMap {
883    for (key, usages) in other {
884        base.entry(key).or_default().extend(usages);
885    }
886    base
887}
888
889fn merge_lineage_map_inplace(base: &mut LineageMap, other: &LineageMap) {
890    for (key, usages) in other {
891        base.entry(key.clone()).or_default().extend(usages.clone());
892    }
893}
894
/// Compute the column lineage of `expr`: for every base column the expression
/// touches, the chain of operations applied on the way to the expression's
/// value.
///
/// Recognised "compatible" operations (`date_part` with a supported literal
/// component, `variant_get` with a literal path, substring-only LIKE
/// patterns) are recorded as specific `Operation`s; anything else degrades to
/// `Operation::Other`, which disqualifies the column from metadata annotation
/// downstream.
fn lineage_for_expr(
    expr: &Expr,
    input_lineage: &LineageMap,
    schema: &DFSchema,
) -> Result<Vec<ColumnUsage>> {
    match expr {
        // A bare column either forwards the lineage already computed for the
        // input node, or starts a fresh base usage typed from the schema.
        Expr::Column(column) => {
            let key = ColumnKey::from_column(column);
            if let Some(usages) = input_lineage.get(&key) {
                Ok(usages.clone())
            } else {
                let field = schema.field_from_column(column)?;
                Ok(vec![ColumnUsage::new_base(
                    column,
                    field.data_type().clone(),
                )])
            }
        }
        // Aliases are transparent for lineage purposes.
        Expr::Alias(alias) => lineage_for_expr(&alias.expr, input_lineage, schema),
        Expr::ScalarFunction(func) => {
            let func_name = func.func.name();
            // date_part('<unit>', col): record a structured Extract operation
            // when the unit literal is one we support.
            if func_name.eq_ignore_ascii_case("date_part")
                && func.args.len() == 2
                && let Some(component) = part_to_unit(&func.args[0])
            {
                let mut usages = lineage_for_expr(&func.args[1], input_lineage, schema)?;
                for usage in &mut usages {
                    usage.operations.push(Operation::Extract(component));
                }
                return Ok(usages);
            } else if func_name.eq_ignore_ascii_case("variant_get")
                && (func.args.len() == 2 || func.args.len() == 3)
                && let Some(path) = literal_utf8(&func.args[1])
            {
                // variant_get(col, 'path'[, 'type']): record the accessed
                // path plus the optional parsed type hint.
                let type_hint = func.args.get(2).and_then(literal_data_type);
                let mut usages = lineage_for_expr(&func.args[0], input_lineage, schema)?;
                for usage in &mut usages {
                    usage.operations.push(Operation::VariantGet {
                        path: path.clone(),
                        data_type: type_hint.clone(),
                    });
                }
                return Ok(usages);
            }
            // Any other scalar function is an opaque use of its inputs.
            propagate_other(expr, input_lineage, schema)
        }
        Expr::Like(like) => {
            // Only plain, case-sensitive `%needle%` patterns with no escape
            // char qualify as a recognised substring search.
            if !like.case_insensitive
                && like.escape_char.is_none()
                && let Some(pattern) = literal_utf8(&like.pattern)
                && is_substring_pattern(pattern.as_bytes())
            {
                let mut usages = lineage_for_expr(&like.expr, input_lineage, schema)?;
                for usage in &mut usages {
                    usage.operations.push(Operation::SubstringSearch);
                }
                return Ok(usages);
            }
            propagate_other(expr, input_lineage, schema)
        }
        // Casts change the value representation, so the column can no longer
        // be served by an extraction-only projection: mark as Other.
        Expr::Cast(cast) => {
            let mut usages = lineage_for_expr(&cast.expr, input_lineage, schema)?;
            for usage in &mut usages {
                usage.operations.push(Operation::Other);
            }
            Ok(usages)
        }
        Expr::TryCast(cast) => {
            let mut usages = lineage_for_expr(&cast.expr, input_lineage, schema)?;
            for usage in &mut usages {
                usage.operations.push(Operation::Other);
            }
            Ok(usages)
        }
        // Leaf expressions that reference no input column contribute nothing.
        Expr::Literal(_, _) => Ok(Vec::new()),
        Expr::ScalarSubquery(_) | Expr::Exists { .. } => Ok(Vec::new()),
        Expr::Placeholder(_) => Ok(Vec::new()),
        #[allow(deprecated)]
        Expr::Wildcard { .. } => {
            // `*` uses every input column exactly as-is.
            let mut usages = Vec::new();
            for column_usages in input_lineage.values() {
                usages.extend(column_usages.clone());
            }
            Ok(usages)
        }
        // Fallback: all other expression kinds use their children opaquely.
        _ => propagate_other(expr, input_lineage, schema),
    }
}
983
984fn propagate_other(
985    expr: &Expr,
986    input_lineage: &LineageMap,
987    schema: &DFSchema,
988) -> Result<Vec<ColumnUsage>> {
989    let mut combined: Vec<ColumnUsage> = Vec::new();
990    expr.apply_children(|child| {
991        let mut usages = lineage_for_expr(child, input_lineage, schema)?;
992        for usage in &mut usages {
993            usage.operations.push(Operation::Other);
994        }
995        combined.extend(usages);
996        Ok(TreeNodeRecursion::Continue)
997    })?;
998    Ok(combined)
999}
1000
1001fn literal_utf8(expr: &Expr) -> Option<String> {
1002    match expr {
1003        Expr::Literal(value, _) => match value {
1004            ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) => Some(v.clone()),
1005            ScalarValue::Utf8View(Some(v)) => Some(v.clone()),
1006            _ => None,
1007        },
1008        _ => None,
1009    }
1010}
1011
/// True when a LIKE pattern has the shape `%needle%`: a non-empty literal
/// needle between exactly one leading and one trailing `%`, with no further
/// `%` or `_` metacharacters inside. Such patterns are plain substring
/// searches.
fn is_substring_pattern(pattern: &[u8]) -> bool {
    // The slice pattern requires len >= 2, so `"%"` alone is rejected and
    // `"%%"` yields an empty (rejected) inner needle.
    let [b'%', inner @ .., b'%'] = pattern else {
        return false;
    };
    !inner.is_empty() && inner.iter().all(|&b| b != b'%' && b != b'_')
}
1025
1026fn is_string_type(data_type: &DataType) -> bool {
1027    match data_type {
1028        DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => true,
1029        DataType::Dictionary(_, value_type) => is_string_type(value_type.as_ref()),
1030        _ => false,
1031    }
1032}
1033
1034fn is_date_part_type(data_type: &DataType) -> bool {
1035    matches!(data_type, DataType::Date32 | DataType::Timestamp(_, _))
1036}
1037
1038fn literal_data_type(expr: &Expr) -> Option<DataType> {
1039    literal_utf8(expr).and_then(|spec| DataType::from_str(&spec).ok())
1040}
1041
1042fn part_to_unit(expr: &Expr) -> Option<SupportedIntervalUnit> {
1043    let value = match expr {
1044        Expr::Literal(literal, _) => literal,
1045        _ => return None,
1046    };
1047    let text = match value {
1048        ScalarValue::Utf8(Some(v))
1049        | ScalarValue::LargeUtf8(Some(v))
1050        | ScalarValue::Utf8View(Some(v)) => v.as_str(),
1051        _ => return None,
1052    };
1053    let lowered = text.to_ascii_lowercase();
1054    match lowered.as_str() {
1055        "dow" | "dayofweek" | "day_of_week" => {
1056            return Some(SupportedIntervalUnit::DayOfWeek);
1057        }
1058        _ => {}
1059    }
1060    let unit = IntervalUnit::from_str(text).ok()?;
1061    match unit {
1062        IntervalUnit::Year => Some(SupportedIntervalUnit::Year),
1063        IntervalUnit::Month => Some(SupportedIntervalUnit::Month),
1064        IntervalUnit::Day => Some(SupportedIntervalUnit::Day),
1065        _ => None,
1066    }
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071    use std::path::PathBuf;
1072
1073    use crate::optimizers::{
1074        DATE_MAPPING_METADATA_KEY, LocalModeOptimizer, VARIANT_MAPPING_METADATA_KEY,
1075    };
1076    use crate::{LiquidCacheParquet, VariantGetUdf, VariantToJsonUdf};
1077    use liquid_cache::cache::AlwaysHydrate;
1078    use liquid_cache_common::IoMode;
1079
1080    use super::*;
1081    use arrow::array::{ArrayRef, Date32Array, StringArray, TimestampMicrosecondArray};
1082    use arrow_schema::{Field, Schema, TimeUnit};
1083    use datafusion::catalog::memory::DataSourceExec;
1084    use datafusion::datasource::physical_plan::FileScanConfig;
1085    use datafusion::execution::SessionStateBuilder;
1086    use datafusion::logical_expr::ScalarUDF;
1087    use datafusion::physical_plan::ExecutionPlan;
1088    use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
1089    use liquid_cache::cache::squeeze_policies::TranscodeSqueezeEvict;
1090    use liquid_cache::cache_policies::LiquidPolicy;
1091    use parquet::arrow::ArrowWriter;
1092    use parquet::variant::{VariantArray, json_to_variant};
1093    use serde::Deserialize;
1094    use tempfile::TempDir;
1095
1096    // ─────────────────────────────────────────────────────────────────────────────
1097    // Setup helpers - lean versions for different test scenarios
1098    // ─────────────────────────────────────────────────────────────────────────────
1099
    /// Physical optimizer under test, backed by an in-process
    /// `LiquidCacheParquet` (io_uring I/O, transcode/squeeze/evict squeeze
    /// policy, always-hydrate hydration; see `LiquidCacheParquet::new` for
    /// the meaning of the numeric arguments — presumably batch size and a
    /// 1 GiB budget).
    fn create_physical_optimizer() -> LocalModeOptimizer {
        LocalModeOptimizer::with_cache(Arc::new(LiquidCacheParquet::new(
            1024,
            1024 * 1024 * 1024,
            PathBuf::from("test"),
            Box::new(LiquidPolicy::new()),
            Box::new(TranscodeSqueezeEvict),
            Box::new(AlwaysHydrate::new()),
            IoMode::Uring,
        )))
    }
1111
    /// Session wired with both halves of the lineage pipeline: the logical
    /// `LineageOptimizer` rule plus the physical `LocalModeOptimizer` that
    /// consumes its annotations.
    fn create_session_context(optimizer: Arc<LineageOptimizer>) -> SessionContext {
        let state = SessionStateBuilder::new()
            .with_config(SessionConfig::new())
            .with_default_features()
            .with_optimizer_rule(optimizer as Arc<dyn OptimizerRule + Send + Sync>)
            .with_physical_optimizer_rule(Arc::new(create_physical_optimizer()))
            .build();
        SessionContext::new_with_state(state)
    }
1121
    /// Write a three-row parquet file with a microsecond timestamp column
    /// (`event_ts`) and two identical `Date32` columns (`date`, `date_copy`).
    ///
    /// NOTE(review): the Date32 values are yyyymmdd-shaped integers
    /// (20210101, ...), not days-since-epoch, so they decode to far-future
    /// dates. The metadata tests in this module never inspect the values —
    /// but confirm no test asserts on actual date results before reusing
    /// this fixture.
    fn write_date_parquet(path: &std::path::Path) {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
            Field::new("date", DataType::Date32, false),
            Field::new("date_copy", DataType::Date32, false),
        ]));

        let timestamps: ArrayRef = Arc::new(TimestampMicrosecondArray::from(vec![
            Some(1_609_459_200_000_000),
            Some(1_640_995_200_000_000),
            Some(1_672_358_400_000_000),
        ]));
        let dates: ArrayRef = Arc::new(Date32Array::from(vec![
            Some(20210101),
            Some(20220202),
            Some(20230303),
        ]));
        // Same array reused for both date columns.
        let batch = arrow::record_batch::RecordBatch::try_new(
            Arc::clone(&schema),
            vec![timestamps, dates.clone(), dates],
        )
        .unwrap();

        let file = std::fs::File::create(path).unwrap();
        let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
    }
1154
    /// Write a parquet file whose single column `data` is a variant column
    /// built from three JSON documents (two with `name` and `age`, one with
    /// only `name`).
    fn write_variant_parquet(path: &std::path::Path) {
        let values = StringArray::from(vec![
            Some(r#"{"name": "Alice", "age": 30}"#),
            Some(r#"{"name": "Bob", "age": 25}"#),
            Some(r#"{"name": "Charlie"}"#),
        ]);
        let input_array: ArrayRef = Arc::new(values);
        let variant: VariantArray =
            json_to_variant(&input_array).expect("variant conversion for test data");

        // The variant array supplies its own (metadata-carrying) field.
        let schema = Arc::new(Schema::new(vec![variant.field("data")]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            Arc::clone(&schema),
            vec![ArrayRef::from(variant)],
        )
        .expect("variant batch");

        let file = std::fs::File::create(path).expect("create variant parquet file");
        let mut writer =
            ArrowWriter::try_new(file, batch.schema(), None).expect("create variant writer");
        writer.write(&batch).expect("write variant batch");
        writer.close().expect("close variant writer");
    }
1178
    /// Setup for tests that only need a single date table (table_a).
    ///
    /// Returns the temp dir (must be kept alive for the parquet file to
    /// exist), the session context, and the shared lineage optimizer.
    async fn setup_single_date_table() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let table_path = temp_dir.path().join("table_a.parquet");
        write_date_parquet(&table_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_parquet(
            "table_a",
            table_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }
1197
    /// Setup for tests that need two date tables (table_a and table_b) for
    /// joins; both tables share the same schema and fixture data.
    async fn setup_dual_date_tables() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let table_a_path = temp_dir.path().join("table_a.parquet");
        let table_b_path = temp_dir.path().join("table_b.parquet");
        write_date_parquet(&table_a_path);
        write_date_parquet(&table_b_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_parquet(
            "table_a",
            table_a_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();
        ctx.register_parquet(
            "table_b",
            table_b_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }
1225
    /// Setup for tests that only need a variant table, with the variant UDFs
    /// registered on the session.
    async fn setup_variant_table() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let variant_path = temp_dir.path().join("variants_test.parquet");
        write_variant_parquet(&variant_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default()));
        ctx.register_udf(ScalarUDF::new_from_impl(VariantToJsonUdf::default()));
        ctx.register_parquet(
            "variants_test",
            variant_path.to_str().unwrap(),
            // skip_metadata(false): keep parquet field metadata — presumably
            // required for the variant column to be recognised; confirm.
            ParquetReadOptions::default().skip_metadata(false),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }
1246
    /// Setup for tests that need both a date table (table_a) and a variant
    /// table (variants_test), with the variant UDFs registered.
    async fn setup_date_and_variant_tables() -> (TempDir, SessionContext, Arc<LineageOptimizer>) {
        let temp_dir = TempDir::new().unwrap();
        let table_a_path = temp_dir.path().join("table_a.parquet");
        let variant_path = temp_dir.path().join("variants_test.parquet");
        write_date_parquet(&table_a_path);
        write_variant_parquet(&variant_path);

        let optimizer = Arc::new(LineageOptimizer::new());
        let ctx = create_session_context(optimizer.clone());
        ctx.register_udf(ScalarUDF::new_from_impl(VariantGetUdf::default()));
        ctx.register_udf(ScalarUDF::new_from_impl(VariantToJsonUdf::default()));
        ctx.register_parquet(
            "table_a",
            table_a_path.to_str().unwrap(),
            ParquetReadOptions::default(),
        )
        .await
        .unwrap();
        ctx.register_parquet(
            "variants_test",
            variant_path.to_str().unwrap(),
            // Keep parquet field metadata for the variant column (see
            // setup_variant_table).
            ParquetReadOptions::default().skip_metadata(false),
        )
        .await
        .unwrap();

        (temp_dir, ctx, optimizer)
    }
1276
1277    // ─────────────────────────────────────────────────────────────────────────────
1278    // Test utilities
1279    // ─────────────────────────────────────────────────────────────────────────────
1280
    /// Walk `plan` and collect, for every parquet `DataSourceExec` leaf, the
    /// `field name -> metadata value` pairs of file-schema fields carrying
    /// `metadata_key`. Non-scan nodes are skipped.
    ///
    /// NOTE(review): fields from multiple scans are merged into one map keyed
    /// by bare column name, so same-named columns from different tables
    /// overwrite each other (see `extract_from_joined_tables`).
    fn extract_field_metadata(
        plan: &Arc<dyn ExecutionPlan>,
        metadata_key: &str,
    ) -> HashMap<String, String> {
        let mut field_metadata_map = HashMap::new();

        plan.apply(|node| {
            // Only DataSourceExec nodes backed by a FileScanConfig expose a
            // file schema to inspect; everything else is passed through.
            let Some(data_source) = node.as_any().downcast_ref::<DataSourceExec>() else {
                return Ok(TreeNodeRecursion::Continue);
            };
            let Some(file_scan_config) = data_source
                .data_source()
                .as_any()
                .downcast_ref::<FileScanConfig>()
            else {
                return Ok(TreeNodeRecursion::Continue);
            };

            let file_schema = &file_scan_config.file_schema();
            for field in file_schema.fields() {
                if let Some(metadata_value) = field.metadata().get(metadata_key) {
                    field_metadata_map.insert(field.name().to_string(), metadata_value.clone());
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })
        .unwrap();
        field_metadata_map
    }
1310
    /// One entry of the JSON-encoded variant mapping metadata: an extracted
    /// `path` plus an optional type hint (serialized under the key `type`).
    #[derive(Debug, Deserialize)]
    struct VariantMetadataEntry {
        path: String,
        #[serde(rename = "type")]
        data_type: Option<String>,
    }
1317
1318    fn parse_variant_metadata(value: &str) -> Vec<VariantMetadataEntry> {
1319        serde_json::from_str(value).unwrap_or_else(|_| {
1320            vec![VariantMetadataEntry {
1321                path: value.to_string(),
1322                data_type: None,
1323            }]
1324        })
1325    }
1326
1327    fn variant_paths_from_metadata(value: &str) -> Vec<String> {
1328        parse_variant_metadata(value)
1329            .into_iter()
1330            .map(|entry| entry.path)
1331            .collect()
1332    }
1333
    /// Run `sql` through the logical optimizer and physical planner, then
    /// assert the parquet scan's field metadata: `expected_date` is the exact
    /// `(column, components)` map expected under the date-mapping key, and
    /// `expected_variant` the (order-insensitive) list of variant paths
    /// expected on the `data` column.
    async fn assert_metadata(
        ctx: &SessionContext,
        sql: &str,
        expected_date: Vec<(&str, &str)>,
        expected_variant: Vec<&str>,
    ) {
        let df = ctx.sql(sql).await.unwrap();
        let (state, plan) = df.into_parts();
        // Optimize explicitly so the LineageOptimizer rule runs before
        // physical planning.
        let optimized = state.optimize(&plan).unwrap();
        let physical_plan = state.create_physical_plan(&optimized).await.unwrap();

        let date_metadata = extract_field_metadata(&physical_plan, DATE_MAPPING_METADATA_KEY);
        let variant_metadata = extract_field_metadata(&physical_plan, VARIANT_MAPPING_METADATA_KEY);

        // Check date metadata
        let expected_date_map: HashMap<String, String> = expected_date
            .into_iter()
            .map(|(col, val)| (col.to_string(), val.to_string()))
            .collect();
        assert_eq!(
            date_metadata, expected_date_map,
            "date metadata mismatch for SQL: {}",
            sql
        );

        // Check variant metadata
        if expected_variant.is_empty() {
            assert!(
                !variant_metadata.contains_key("data"),
                "variant metadata should not be present for SQL: {}",
                sql
            );
        } else {
            // Compare sorted path lists so ordering inside the JSON metadata
            // value does not matter.
            let mut actual = variant_metadata
                .get("data")
                .map(|v| variant_paths_from_metadata(v))
                .unwrap_or_default();
            actual.sort();
            let mut expected: Vec<String> = expected_variant
                .into_iter()
                .map(|s| s.to_string())
                .collect();
            expected.sort();
            assert_eq!(
                actual, expected,
                "variant metadata mismatch for SQL: {}",
                sql
            );
        }
    }
1385
1386    // ─────────────────────────────────────────────────────────────────────────────
1387    // Date extraction tests - single table
1388    // ─────────────────────────────────────────────────────────────────────────────
1389
    // A single EXTRACT projection annotates the column with its component.
    #[tokio::test]
    async fn extract_day_basic() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM date) AS day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    // date_part('dow', ...) maps to the DOW component.
    #[tokio::test]
    async fn extract_dow_basic() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT date_part('dow', date) AS dow FROM table_a",
            vec![("date", "DOW")],
            vec![],
        )
        .await;
    }

    // Component spelling is case-insensitive.
    #[tokio::test]
    async fn extract_day_lowercase() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(day FROM date) AS day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    // Table-qualified column references are annotated the same way.
    #[tokio::test]
    async fn extract_day_qualified_column() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM table_a.date) FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    // Extraction survives being wrapped in an aggregate.
    #[tokio::test]
    async fn extract_day_in_avg() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT AVG(EXTRACT(DAY FROM date)) AS avg_day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    // Arithmetic applied AFTER the extraction does not disqualify the column.
    #[tokio::test]
    async fn extract_day_in_expression() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT AVG(EXTRACT(DAY FROM date) + 1) AS avg_day FROM table_a",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }

    // Consistent extraction across scalar subqueries is still recognised.
    #[tokio::test]
    async fn extract_day_in_subqueries() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT (SELECT MAX(EXTRACT(DAY FROM date)) FROM table_a) AS max_day, (SELECT MIN(EXTRACT(DAY FROM date)) FROM table_a) AS min_day",
            vec![("date", "DAY")],
            vec![],
        )
        .await;
    }
1473
1474    // ─────────────────────────────────────────────────────────────────────────────
1475    // Date extraction tests - multiple components
1476    // ─────────────────────────────────────────────────────────────────────────────
1477
    // Two components on the same column merge into one comma-separated
    // metadata value ("MONTH,DAY").
    #[tokio::test]
    async fn extract_day_and_month() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM date) AS day, EXTRACT(MONTH FROM date) AS month FROM table_a",
            vec![("date", "MONTH,DAY")],
            vec![],
        )
        .await;
    }

    // Same merge behaviour when each component comes from its own subquery;
    // mixed-case spelling ("Month") still resolves.
    #[tokio::test]
    async fn extract_day_and_month_subqueries() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT (SELECT MAX(EXTRACT(DAY FROM date)) FROM table_a) AS max_day, (SELECT MIN(EXTRACT(Month FROM date)) FROM table_a) AS min_day",
            vec![("date", "MONTH,DAY")],
            vec![],
        )
        .await;
    }
1501
1502    // ─────────────────────────────────────────────────────────────────────────────
1503    // Date extraction tests - multi table (joins)
1504    // ─────────────────────────────────────────────────────────────────────────────
1505
    #[tokio::test]
    async fn extract_from_joined_tables() {
        let (_dir, ctx, _) = setup_dual_date_tables().await;
        // Both tables have "date" column - metadata HashMap stores by column name only,
        // so we verify at least one extraction is present (iteration order determines which)
        let df = ctx
            .sql("SELECT EXTRACT(YEAR FROM table_a.date) AS year, EXTRACT(DAY FROM table_b.date) AS day FROM table_a INNER JOIN table_b ON table_a.event_ts = table_b.event_ts")
            .await
            .unwrap();
        let (state, plan) = df.into_parts();
        let optimized = state.optimize(&plan).unwrap();
        let physical_plan = state.create_physical_plan(&optimized).await.unwrap();
        let metadata = extract_field_metadata(&physical_plan, DATE_MAPPING_METADATA_KEY);

        // Both tables' date columns should have extraction metadata
        assert!(
            metadata.contains_key("date"),
            "date column should have extraction metadata"
        );
        // Either scan's annotation may win in the merged map, so accept both.
        let value = metadata.get("date").unwrap();
        assert!(
            value == "YEAR" || value == "DAY",
            "expected YEAR or DAY, got {}",
            value
        );
    }
1532
1533    // ─────────────────────────────────────────────────────────────────────────────
1534    // Date extraction tests - no metadata (inconsistent usage)
1535    // ─────────────────────────────────────────────────────────────────────────────
1536
    // Arithmetic on the column BEFORE the extraction makes the usage opaque,
    // so no metadata may be attached.
    #[tokio::test]
    async fn no_extraction_with_interval_arithmetic() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM date + INTERVAL '1 day') AS day FROM table_a",
            vec![],
            vec![],
        )
        .await;
    }

    // A direct (non-extract) use of the column disqualifies it.
    #[tokio::test]
    async fn no_extraction_when_column_used_directly() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(&ctx, "SELECT date FROM table_a", vec![], vec![]).await;
    }

    // Using the column in a join condition is a direct use and disqualifies
    // even the projection-side extraction.
    #[tokio::test]
    async fn no_extraction_when_used_in_join_condition() {
        let (_dir, ctx, _) = setup_dual_date_tables().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(DAY FROM table_a.date) AS day FROM table_a INNER JOIN table_b ON table_a.date = table_b.date",
            vec![],
            vec![],
        )
        .await;
    }

    // Timestamp columns support extraction just like Date32.
    #[tokio::test]
    async fn timestamp_extraction_supported() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(YEAR FROM event_ts) AS year FROM table_a",
            vec![("event_ts", "YEAR")],
            vec![],
        )
        .await;
    }

    // DOW extraction also works on timestamp columns.
    #[tokio::test]
    async fn timestamp_dow_extraction() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        assert_metadata(
            &ctx,
            "SELECT date_part('dow', event_ts) AS dow FROM table_a",
            vec![("event_ts", "DOW")],
            vec![],
        )
        .await;
    }
1590
1591    // ─────────────────────────────────────────────────────────────────────────────
1592    // Date extraction - metadata isolation test
1593    // ─────────────────────────────────────────────────────────────────────────────
1594
    #[tokio::test]
    async fn metadata_only_on_extracted_fields() {
        let (_dir, ctx, _) = setup_single_date_table().await;
        // Only 'date' should have metadata, not 'event_ts'
        // (event_ts is selected directly, which is a disqualifying use).
        assert_metadata(
            &ctx,
            "SELECT EXTRACT(YEAR FROM date) AS year, event_ts FROM table_a",
            vec![("date", "YEAR")],
            vec![],
        )
        .await;
    }
1607
1608    // ─────────────────────────────────────────────────────────────────────────────
1609    // Variant extraction tests - basic
1610    // ─────────────────────────────────────────────────────────────────────────────
1611
    // A single variant_get path is annotated on the base column.
    #[tokio::test]
    async fn variant_get_single_path() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_to_json(variant_get(data, 'name')) FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    // The same path used twice is deduplicated in the metadata.
    #[tokio::test]
    async fn variant_get_duplicate_paths() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name'), variant_get(data, 'name') AS name2 FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    // Wrapping variant_get in variant_to_json keeps the path annotation.
    #[tokio::test]
    async fn variant_get_with_to_json() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_to_json(variant_get(data, 'age')), variant_to_json(variant_get(data, 'age')) AS age2 FROM variants_test",
            vec![],
            vec!["age"],
        )
        .await;
    }

    // Table-qualified column references are annotated the same way.
    #[tokio::test]
    async fn variant_get_qualified_column() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(variants_test.data, 'name') as name1, variant_get(variants_test.data, 'name') as name2 FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    // Aggregates over variant_get keep the annotation.
    #[tokio::test]
    async fn variant_get_in_aggregates() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT COUNT(variant_get(data, 'age')), MAX(variant_get(data, 'age')) FROM variants_test",
            vec![],
            vec!["age"],
        )
        .await;
    }

    // variant_get in the WHERE clause counts as the same compatible usage.
    #[tokio::test]
    async fn variant_get_with_where_clause() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name') FROM variants_test WHERE variant_get(data, 'name') IS NOT NULL",
            vec![],
            vec!["name"],
        )
        .await;
    }

    // Consistent variant_get usage across scalar subqueries is recognised.
    #[tokio::test]
    async fn variant_get_in_subqueries() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT (SELECT MAX(variant_get(data, 'name')) FROM variants_test) AS max_name, (SELECT MIN(variant_get(data, 'name')) FROM variants_test) AS min_name",
            vec![],
            vec!["name"],
        )
        .await;
    }

    // Distinct paths on the same column are all recorded.
    #[tokio::test]
    async fn variant_get_multiple_paths() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name'), variant_get(data, 'date') FROM variants_test",
            vec![],
            vec!["date", "name"],
        )
        .await;
    }

    // Nested variant_get: only the path applied to the base column ('name')
    // is recorded in the metadata.
    #[tokio::test]
    async fn variant_get_nested() {
        let (_dir, ctx, _) = setup_variant_table().await;
        assert_metadata(
            &ctx,
            "SELECT variant_get(variant_get(data, 'name'), 'age') FROM variants_test",
            vec![],
            vec!["name"],
        )
        .await;
    }

    #[tokio::test]
    async fn variant_get_conflicting_types_no_metadata() {
        let (_dir, ctx, _) = setup_variant_table().await;
        // Same path with different type hints - should not produce metadata
        assert_metadata(
            &ctx,
            "SELECT variant_get(data, 'name', 'Utf8'), variant_get(data, 'name') FROM variants_test",
            vec![],
            vec![],
        )
        .await;
    }
1732
1733    // ─────────────────────────────────────────────────────────────────────────────
1734    // Variant extraction tests - type metadata
1735    // ─────────────────────────────────────────────────────────────────────────────
1736
1737    #[tokio::test]
1738    async fn variant_get_type_hint_propagated() {
1739        let (_dir, ctx, _) = setup_variant_table().await;
1740
1741        let df = ctx
1742            .sql("SELECT variant_get(data, 'name', 'Utf8') FROM variants_test")
1743            .await
1744            .unwrap();
1745        let (state, plan) = df.into_parts();
1746        let optimized = state.optimize(&plan).unwrap();
1747        let physical_plan = state.create_physical_plan(&optimized).await.unwrap();
1748
1749        let metadata = extract_field_metadata(&physical_plan, VARIANT_MAPPING_METADATA_KEY);
1750
1751        let entries = metadata
1752            .get("data")
1753            .map(|value| parse_variant_metadata(value))
1754            .unwrap_or_default();
1755        let entry = entries
1756            .iter()
1757            .find(|entry| entry.path == "name")
1758            .expect("variant metadata entry for name");
1759        assert_eq!(entry.data_type.as_deref(), Some("Utf8"));
1760    }
1761
1762    #[tokio::test]
1763    async fn variant_get_conflicting_types_in_filter() {
1764        let (_dir, ctx, _) = setup_variant_table().await;
1765        assert_metadata(
1766            &ctx,
1767            "SELECT variant_to_json(variant_get(data, 'name')) FROM variants_test WHERE variant_get(data, 'name', 'Utf8') = 'Bob'",
1768            vec![],
1769            vec![],
1770        )
1771        .await;
1772    }
1773
1774    // ─────────────────────────────────────────────────────────────────────────────
1775    // Variant extraction tests - edge cases
1776    // ─────────────────────────────────────────────────────────────────────────────
1777
1778    #[tokio::test]
1779    async fn variant_get_multiple_paths_with_types() {
1780        let (_dir, ctx, _) = setup_variant_table().await;
1781        assert_metadata(
1782            &ctx,
1783            "SELECT variant_get(data, 'did', 'Utf8') as user_id,
1784             MAX(TO_TIMESTAMP_MICROS(variant_get(data, 'time_us', 'Int64'))) - MIN(TO_TIMESTAMP_MICROS(variant_get(data, 'time_us', 'Int64')))
1785            FROM variants_test GROUP BY user_id",
1786            vec![],
1787            vec!["did", "time_us"],
1788        )
1789        .await;
1790    }
1791
1792    // ─────────────────────────────────────────────────────────────────────────────
1793    // Mixed date extract and variant tests
1794    // ─────────────────────────────────────────────────────────────────────────────
1795
1796    #[tokio::test]
1797    async fn mixed_date_and_variant_basic() {
1798        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
1799        assert_metadata(
1800            &ctx,
1801            "SELECT EXTRACT(DAY FROM table_a.date) AS day, variant_get(variants_test.data, 'name') AS name FROM table_a CROSS JOIN variants_test",
1802            vec![("date", "DAY")],
1803            vec!["name"],
1804        )
1805        .await;
1806    }
1807
1808    #[tokio::test]
1809    async fn mixed_multiple_date_and_variant() {
1810        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
1811        assert_metadata(
1812            &ctx,
1813            "SELECT EXTRACT(YEAR FROM table_a.date) AS year, EXTRACT(MONTH FROM table_a.date_copy) AS month, variant_get(variants_test.data, 'name') AS name, variant_get(variants_test.data, 'age') AS age FROM table_a CROSS JOIN variants_test",
1814            vec![("date", "YEAR"), ("date_copy", "MONTH")],
1815            vec!["age", "name"],
1816        )
1817        .await;
1818    }
1819
1820    #[tokio::test]
1821    async fn mixed_date_in_where_clause() {
1822        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
1823        assert_metadata(
1824            &ctx,
1825            "SELECT variant_get(variants_test.data, 'name') AS name FROM variants_test CROSS JOIN table_a WHERE EXTRACT(DAY FROM table_a.date) > 1",
1826            vec![("date", "DAY")],
1827            vec!["name"],
1828        )
1829        .await;
1830    }
1831
1832    #[tokio::test]
1833    async fn mixed_variant_in_where_clause() {
1834        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
1835        assert_metadata(
1836            &ctx,
1837            "SELECT EXTRACT(DAY FROM table_a.date) AS day FROM table_a CROSS JOIN variants_test WHERE variant_get(variants_test.data, 'name') IS NOT NULL",
1838            vec![("date", "DAY")],
1839            vec!["name"],
1840        )
1841        .await;
1842    }
1843
1844    #[tokio::test]
1845    async fn mixed_date_with_variant_subquery() {
1846        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
1847        assert_metadata(
1848            &ctx,
1849            "SELECT EXTRACT(YEAR FROM table_a.date) AS year, (SELECT variant_get(variants_test.data, 'name') FROM variants_test LIMIT 1) AS name FROM table_a",
1850            vec![("date", "YEAR")],
1851            vec!["name"],
1852        )
1853        .await;
1854    }
1855
1856    #[tokio::test]
1857    async fn mixed_in_aggregates() {
1858        let (_dir, ctx, _) = setup_date_and_variant_tables().await;
1859        assert_metadata(
1860            &ctx,
1861            "SELECT AVG(EXTRACT(DAY FROM table_a.date)) AS avg_day, COUNT(variant_get(variants_test.data, 'name')) AS name_count FROM table_a CROSS JOIN variants_test",
1862            vec![("date", "DAY")],
1863            vec!["name"],
1864        )
1865        .await;
1866    }
1867}