use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use datafusion::common::Result;
use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion::logical_expr::LogicalPlan;
use datafusion::optimizer::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use samkhya_core::stats::ColumnStats;
use crate::physical_plan::SamkhyaStatsExec;
use crate::stats_provider::to_datafusion_column_statistics;
/// Rule object registered under the name `samkhya_cardinality_correction`.
///
/// The only state it carries is an atomic counter that can be read back
/// through [`SamkhyaOptimizerRule::samkhya_leaves_seen`].
#[derive(Debug, Default)]
pub struct SamkhyaOptimizerRule {
    /// Counter exposed by `samkhya_leaves_seen()`; zero on a fresh rule.
    samkhya_leaves_seen: AtomicUsize,
}

impl Clone for SamkhyaOptimizerRule {
    /// Produces an independent rule whose counter starts at the value the
    /// source rule held at the moment of cloning. The two counters are not
    /// shared afterwards.
    fn clone(&self) -> Self {
        let snapshot = self.samkhya_leaves_seen();
        Self {
            samkhya_leaves_seen: AtomicUsize::new(snapshot),
        }
    }
}

impl SamkhyaOptimizerRule {
    /// Creates a rule whose counter is zero.
    pub fn new() -> Self {
        Self {
            samkhya_leaves_seen: AtomicUsize::new(0),
        }
    }

    /// Creates a fresh rule already wrapped in an [`Arc`].
    pub fn arc() -> Arc<Self> {
        let rule = Self::new();
        Arc::new(rule)
    }

    /// Returns the current counter value (sequentially consistent load).
    pub fn samkhya_leaves_seen(&self) -> usize {
        let counter = &self.samkhya_leaves_seen;
        counter.load(Ordering::SeqCst)
    }
}
impl OptimizerRule for SamkhyaOptimizerRule {
    /// Fixed identifier of this rule.
    fn name(&self) -> &str {
        "samkhya_cardinality_correction"
    }

    /// Requests bottom-up application.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    /// Advertises that `rewrite` is implemented.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Walks `plan` and, for every `TableScan` node, computes corrected
    /// statistics for each projected column and converts them with
    /// `to_datafusion_column_statistics`.
    ///
    /// The converted statistics are currently discarded and the plan is
    /// returned unchanged (`Transformed::no`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tree traversal.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // NOTE(review): with `ApplyOrder::BottomUp` the optimizer presumably
        // already calls `rewrite` once per node, in which case this
        // full-subtree walk revisits the same scans repeatedly — confirm
        // before attaching real work here.
        plan.apply(|node| {
            if let LogicalPlan::TableScan(scan) = node {
                // The table name is identical for every column of the scan,
                // so render it once instead of once per column.
                let table_name = scan.table_name.to_string();
                let n_cols = scan.projected_schema.fields().len();
                for col_idx in 0..n_cols {
                    let corrected = compute_corrected_stats(&table_name, col_idx);
                    // Result intentionally unused for now; the plan is not rewritten.
                    let _df_stats = to_datafusion_column_statistics(&corrected);
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        Ok(Transformed::no(plan))
    }
}
impl PhysicalOptimizerRule for SamkhyaOptimizerRule {
    /// Fixed identifier of this rule.
    fn name(&self) -> &str {
        "samkhya_cardinality_correction"
    }

    /// Reports `true` for the schema check flag.
    fn schema_check(&self) -> bool {
        true
    }

    /// Traverses the physical plan, tallies the nodes whose concrete type is
    /// `SamkhyaStatsExec`, stores that tally in the rule's counter
    /// (overwriting the previous value), and hands the plan back unmodified.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tree traversal.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut stats_nodes = 0usize;
        plan.apply(|candidate| {
            let is_stats_exec = candidate.as_any().is::<SamkhyaStatsExec>();
            if is_stats_exec {
                stats_nodes += 1;
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        // Publish the tally from this run; earlier values are replaced, not summed.
        self.samkhya_leaves_seen
            .store(stats_nodes, Ordering::SeqCst);
        Ok(plan)
    }
}
/// Builds the column statistics used by the logical rewrite pass.
///
/// Both arguments are currently ignored: every call yields the same
/// hard-coded values (row count 1000, distinct count 100, null count 0,
/// upper bound 10000).
pub fn compute_corrected_stats(_table: &str, _col_idx: usize) -> ColumnStats {
    // Apply the builder steps one at a time, in the same order as before.
    let stats = ColumnStats::new();
    let stats = stats.with_row_count(1_000);
    let stats = stats.with_distinct_count(100);
    let stats = stats.with_null_count(0);
    stats.with_upper_bound(10_000)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both trait implementations report the same fixed name, and the
    /// logical rule advertises bottom-up rewriting.
    #[test]
    fn rule_has_stable_name() {
        const EXPECTED: &str = "samkhya_cardinality_correction";
        let rule = SamkhyaOptimizerRule::new();

        let logical_name = OptimizerRule::name(&rule);
        let physical_name = PhysicalOptimizerRule::name(&rule);
        assert_eq!(logical_name, EXPECTED);
        assert_eq!(physical_name, EXPECTED);

        assert!(rule.supports_rewrite());
        let order = rule.apply_order();
        assert!(matches!(order, Some(ApplyOrder::BottomUp)));
    }

    /// The placeholder statistics carry the hard-coded values.
    #[test]
    fn placeholder_stats_are_populated() {
        let stats = compute_corrected_stats("t", 0);
        assert_eq!(stats.row_count, Some(1_000));
        assert_eq!(stats.distinct_count, Some(100));
        assert_eq!(stats.upper_bound_rows, Some(10_000));
    }

    /// A freshly constructed rule has a zero counter.
    #[test]
    fn leaves_seen_starts_at_zero() {
        let fresh = SamkhyaOptimizerRule::new();
        let seen = fresh.samkhya_leaves_seen();
        assert_eq!(seen, 0);
    }
}