use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::common::Statistics;
use datafusion::common::stats::Precision;
use datafusion::logical_expr::TableType;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::Expr;
use std::any::Any;
use std::sync::Arc;
use super::distributed::DistributedExecutor;
use super::distributed_exec::DistributedSpireExec;
use super::exec::SpireExec;
use super::pool::ConnectionPool;
use super::pruning::{KeyBounds, extract_key_bounds};
use super::routing::RegionRouter;
use super::statistics::StatisticsProvider;
use super::topology::ClusterTopology;
use datafusion::catalog::Session;
/// DataFusion [`TableProvider`] for a single Spire table.
///
/// Scans are planned as a single-shard [`SpireExec`] when exactly one region
/// matches the scan filters and a client for its leader can be obtained, and as
/// a [`DistributedSpireExec`] otherwise.
pub struct SpireProvider {
    /// Name of the served table; used for region lookup and cached statistics.
    table_name: String,
    /// Arrow schema reported to DataFusion and handed to the execution plans.
    schema: SchemaRef,
    /// Executor shared with every [`DistributedSpireExec`] built by this provider.
    executor: Arc<DistributedExecutor>,
    /// Primary-key column from which key bounds are derived out of scan filters.
    pk_column: String,
    /// Source of cached table and per-column statistics.
    stats_provider: Arc<StatisticsProvider>,
    /// Supplies data-access clients for the single-shard fast path.
    connection_pool: Arc<ConnectionPool>,
    /// Resolves the regions that make up the table.
    region_router: Arc<RegionRouter>,
    /// Maps store IDs (e.g. a region's leader) to network addresses.
    cluster_topology: Arc<ClusterTopology>,
}
impl std::fmt::Debug for SpireProvider {
    /// Formats the provider by its identifying fields only.
    ///
    /// The schema and the shared runtime handles are not printed;
    /// `finish_non_exhaustive` appends `..` so readers can see that fields were
    /// omitted rather than absent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SpireProvider")
            .field("table_name", &self.table_name)
            .field("pk_column", &self.pk_column)
            .finish_non_exhaustive()
    }
}
impl SpireProvider {
    /// Creates a provider whose scans are planned against the Spire cluster.
    ///
    /// Every collaborator is taken as a shared handle and stored as-is.
    #[allow(clippy::too_many_arguments)]
    pub fn with_distributed(
        table_name: String,
        schema: SchemaRef,
        executor: Arc<DistributedExecutor>,
        pk_column: String,
        stats_provider: Arc<StatisticsProvider>,
        connection_pool: Arc<ConnectionPool>,
        region_router: Arc<RegionRouter>,
        cluster_topology: Arc<ClusterTopology>,
    ) -> Self {
        Self {
            table_name,
            schema,
            executor,
            pk_column,
            stats_provider,
            connection_pool,
            region_router,
            cluster_topology,
        }
    }

    /// Reports whether this provider uses distributed execution.
    ///
    /// Always `true`: `with_distributed` is the only constructor.
    // `dead_code` is allowed because this presumably has no callers yet in some
    // builds — TODO confirm and drop the allowance once it is used.
    #[allow(dead_code)]
    pub fn is_distributed(&self) -> bool {
        true
    }

    /// Returns the table's regions whose key range can intersect `key_bounds`.
    ///
    /// An empty `start_key`/`end_key` on a region is treated as unbounded on that
    /// side. If `key_bounds` is unbounded, all regions are returned. If the region
    /// list cannot be fetched, a warning is logged and an empty list is returned;
    /// `scan` then falls back to distributed execution.
    async fn get_matching_regions(
        &self,
        key_bounds: &KeyBounds,
    ) -> Vec<super::routing::RegionInfo> {
        let regions = match self.region_router.get_table_regions(&self.table_name).await {
            Ok(r) => r,
            Err(e) => {
                log::warn!("Failed to get regions for {}: {}", self.table_name, e);
                return vec![];
            }
        };
        if !key_bounds.is_bounded() {
            return (*regions).clone();
        }
        // Clone only the surviving regions instead of the whole list up front.
        regions
            .iter()
            .filter(|region| Self::region_overlaps(region, key_bounds))
            .cloned()
            .collect()
    }

    /// Returns `true` when `region`'s key range may intersect `key_bounds`.
    ///
    /// Regions are compared as half-open `[start_key, end_key)` ranges. The check
    /// also treats `key_bounds.end_key` as exclusive. NOTE(review): confirm
    /// against `extract_key_bounds`. If that bound is inclusive, a region starting
    /// exactly at `end_key` is pruned wrongly.
    fn region_overlaps(region: &super::routing::RegionInfo, key_bounds: &KeyBounds) -> bool {
        let ends_after_start = key_bounds.start_key.as_ref().is_none_or(|start| {
            region.end_key.is_empty() || region.end_key.as_slice() > start.as_slice()
        });
        let starts_before_end = key_bounds.end_key.as_ref().is_none_or(|end| {
            region.start_key.is_empty() || region.start_key.as_slice() < end.as_slice()
        });
        ends_after_start && starts_before_end
    }
}
#[async_trait]
impl TableProvider for SpireProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Builds table statistics from the [`StatisticsProvider`] cache.
    ///
    /// Returns `None` when nothing is cached for this table. Schema columns that
    /// have no cached entry get `ColumnStatistics::new_unknown()`.
    ///
    /// All values are reported as `Precision::Inexact`. They come from a cache
    /// and can lag behind concurrent writes, whereas `Precision::Exact` promises
    /// the value is correct at query time, and optimizer rules may rely on that.
    fn statistics(&self) -> Option<Statistics> {
        let Some(cached) = self.stats_provider.get_cached_stats(&self.table_name) else {
            log::debug!("No cached statistics for table '{}'", self.table_name);
            return None;
        };
        log::debug!(
            "Using cached statistics for '{}': {} rows, {} bytes",
            self.table_name,
            cached.row_count,
            cached.size_bytes
        );
        let column_statistics: Vec<datafusion::common::ColumnStatistics> = self
            .schema
            .fields()
            .iter()
            .map(|field| match cached.column_stats.get(field.name()) {
                Some(col_stats) => datafusion::common::ColumnStatistics {
                    null_count: Precision::Inexact(col_stats.null_count as usize),
                    distinct_count: Precision::Inexact(col_stats.distinct_count as usize),
                    min_value: col_stats
                        .min_value
                        .clone()
                        .map_or(Precision::Absent, Precision::Inexact),
                    max_value: col_stats
                        .max_value
                        .clone()
                        .map_or(Precision::Absent, Precision::Inexact),
                    sum_value: Precision::Absent,
                    byte_size: Precision::Absent,
                },
                None => datafusion::common::ColumnStatistics::new_unknown(),
            })
            .collect();
        Some(Statistics {
            num_rows: Precision::Inexact(cached.row_count as usize),
            total_byte_size: Precision::Inexact(cached.size_bytes as usize),
            column_statistics,
        })
    }

    /// Plans a scan of the table.
    ///
    /// Key bounds on `pk_column` are extracted from `filters` and used to prune
    /// regions. If exactly one region remains and a client for its leader can be
    /// obtained, a single-shard [`SpireExec`] is returned. In every other case a
    /// [`DistributedSpireExec`] is returned: zero or several regions, an unknown
    /// leader address, or a client error.
    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // NOTE(review): this impl does not override `supports_filters_pushdown`.
        // DataFusion's default treats filters as unsupported, in which case they
        // may never reach `scan`, and the key-range pruning below would be
        // inactive. Confirm this against the DataFusion version in use.
        let key_bounds = extract_key_bounds(filters, &self.pk_column);
        let matching_regions = self.get_matching_regions(&key_bounds).await;

        if let [region] = matching_regions.as_slice() {
            let leader_id = region.leader_store_id;
            match self.cluster_topology.get_store_address(leader_id) {
                Some(addr) => match self.connection_pool.get_data_access_client(&addr).await {
                    Ok(client) => {
                        log::debug!(
                            "Using single-shard SpireExec for table '{}' (region {}, leader {})",
                            self.table_name,
                            region.region_id,
                            leader_id
                        );
                        return Ok(Arc::new(SpireExec::new(
                            client,
                            self.table_name.clone(),
                            self.schema.clone(),
                            projection.cloned(),
                            filters.to_vec(),
                            limit,
                        )));
                    }
                    Err(e) => {
                        log::warn!(
                            "Failed to get client for leader {}: {}, falling back to distributed",
                            leader_id,
                            e
                        );
                    }
                },
                None => {
                    // Previously this fallback was silent; make it visible in logs.
                    log::debug!(
                        "No address known for leader {} of region {}, falling back to distributed",
                        leader_id,
                        region.region_id
                    );
                }
            }
        }

        log::debug!(
            "Using distributed execution for table '{}' ({} regions, {} filters)",
            self.table_name,
            matching_regions.len(),
            filters.len()
        );
        Ok(Arc::new(DistributedSpireExec::new(
            self.executor.clone(),
            self.table_name.clone(),
            self.schema.clone(),
            projection,
            filters,
            &self.pk_column,
            limit,
        )))
    }
}