//! DataFusion `TableProvider` implementations for Apache Iceberg tables: a
//! catalog-backed provider supporting reads and appends, and a static,
//! read-only provider pinned to a loaded table or a specific snapshot.
pub mod metadata_table;
pub mod table_provider_factory;
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
use datafusion::catalog::Session;
use datafusion::common::DataFusionError;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use iceberg::arrow::schema_to_arrow_schema;
use iceberg::inspect::MetadataTableType;
use iceberg::spec::TableProperties;
use iceberg::table::Table;
use iceberg::{Catalog, Error, ErrorKind, NamespaceIdent, Result, TableIdent};
use metadata_table::IcebergMetadataTableProvider;
use crate::error::to_datafusion_error;
use crate::physical_plan::commit::IcebergCommitExec;
use crate::physical_plan::project::project_with_partition;
use crate::physical_plan::repartition::repartition;
use crate::physical_plan::scan::IcebergTableScan;
use crate::physical_plan::sort::sort_by_partition;
use crate::physical_plan::write::IcebergWriteExec;
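/// A DataFusion [`TableProvider`] backed by an Iceberg [`Catalog`].
///
/// The table is re-loaded from the catalog on every `scan` and `insert_into`
/// call, so queries always run against the latest committed snapshot.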
#[derive(Debug, Clone)]
pub struct IcebergTableProvider {
catalog: Arc<dyn Catalog>,
table_ident: TableIdent,
schema: ArrowSchemaRef,
}
impl IcebergTableProvider {
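/// Loads the table identified by `namespace` and `name` from the catalog and
/// builds a provider exposing the table's current schema as an Arrow schema.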
pub(crate) async fn try_new(
catalog: Arc<dyn Catalog>,
namespace: NamespaceIdent,
name: impl Into<String>,
) -> Result<Self> {
let table_ident = TableIdent::new(namespace, name.into());
let table = catalog.load_table(&table_ident).await?;
let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
Ok(IcebergTableProvider {
catalog,
table_ident,
schema,
})
}
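/// Builds a provider for one of the table's metadata tables (e.g. snapshots
/// or manifests), loading the table fresh from the catalog.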
pub(crate) async fn metadata_table(
&self,
r#type: MetadataTableType,
) -> Result<IcebergMetadataTableProvider> {
let table = self.catalog.load_table(&self.table_ident).await?;
Ok(IcebergMetadataTableProvider { table, r#type })
}
}
#[async_trait]
impl TableProvider for IcebergTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> ArrowSchemaRef {
self.schema.clone()
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
let table = self
.catalog
.load_table(&self.table_ident)
.await
.map_err(to_datafusion_error)?;
Ok(Arc::new(IcebergTableScan::new(
table,
None, // snapshot_id: None scans the table's current snapshot
self.schema.clone(),
projection,
filters,
limit,
)))
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> DFResult<Vec<TableProviderFilterPushDown>> {
// Filters prune files during the Iceberg scan but are not guaranteed to be
// applied exactly, so DataFusion must re-evaluate them; hence `Inexact`.
Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
}
async fn insert_into(
&self,
state: &dyn Session,
input: Arc<dyn ExecutionPlan>,
_insert_op: InsertOp,
) -> DFResult<Arc<dyn ExecutionPlan>> {
let table = self
.catalog
.load_table(&self.table_ident)
.await
.map_err(to_datafusion_error)?;
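// For partitioned tables, project partition values onto the input so
// downstream operators can route rows to the right partition.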
let partition_spec = table.metadata().default_partition_spec();
let plan_with_partition = if !partition_spec.is_unpartitioned() {
project_with_partition(input, &table)?
} else {
input
};
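// Repartition for parallel writes; `target_partitions` comes from the
// session config and must be non-zero.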
let target_partitions =
NonZeroUsize::new(state.config().target_partitions()).ok_or_else(|| {
DataFusionError::Configuration(
"target_partitions must be greater than 0".to_string(),
)
})?;
let repartitioned_plan =
repartition(plan_with_partition, table.metadata_ref(), target_partitions)?;
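// Read the fanout-enabled table property; a value other than "true" or
// "false" is rejected rather than silently ignored.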
let fanout_enabled = table
.metadata()
.properties()
.get(TableProperties::PROPERTY_DATAFUSION_WRITE_FANOUT_ENABLED)
.map(|value| {
value
.parse::<bool>()
.map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
format!(
"Invalid value for {}, expected 'true' or 'false'",
TableProperties::PROPERTY_DATAFUSION_WRITE_FANOUT_ENABLED
),
)
.with_source(e)
})
.map_err(to_datafusion_error)
})
.transpose()?
.unwrap_or(TableProperties::PROPERTY_DATAFUSION_WRITE_FANOUT_ENABLED_DEFAULT);
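// A fanout writer accepts unsorted input; without fanout, rows must be
// sorted by partition so each partition is written out contiguously.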
let write_input = if fanout_enabled {
repartitioned_plan
} else {
sort_by_partition(repartitioned_plan)?
};
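// Write data files in parallel, then coalesce into a single partition so
// the commit executes exactly once.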
let write_plan = Arc::new(IcebergWriteExec::new(
table.clone(),
write_input,
self.schema.clone(),
));
let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(write_plan));
Ok(Arc::new(IcebergCommitExec::new(
table,
self.catalog.clone(),
coalesce_partitions,
self.schema.clone(),
)))
}
}
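/// A read-only [`TableProvider`] for an already-loaded Iceberg [`Table`],
/// optionally pinned to a specific snapshot. No catalog is consulted, so
/// writes are rejected and later snapshots are never picked up.
///
/// A minimal usage sketch, assuming `table` is an already-loaded [`Table`]:
///
/// ```ignore
/// let provider = IcebergStaticTableProvider::try_new_from_table(table).await?;
/// let ctx = SessionContext::new();
/// ctx.register_table("my_table", Arc::new(provider))?;
/// let df = ctx.sql("SELECT * FROM my_table").await?;
/// ```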
#[derive(Debug, Clone)]
pub struct IcebergStaticTableProvider {
table: Table,
snapshot_id: Option<i64>,
schema: ArrowSchemaRef,
}
impl IcebergStaticTableProvider {
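/// Builds a provider over the table's current schema, scanning the current
/// snapshot.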
pub async fn try_new_from_table(table: Table) -> Result<Self> {
let schema = Arc::new(schema_to_arrow_schema(table.metadata().current_schema())?);
Ok(IcebergStaticTableProvider {
table,
snapshot_id: None,
schema,
})
}
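/// Builds a provider pinned to `snapshot_id`, using that snapshot's schema;
/// fails if the snapshot does not exist in the table metadata.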
pub async fn try_new_from_table_snapshot(table: Table, snapshot_id: i64) -> Result<Self> {
let snapshot = table
.metadata()
.snapshot_by_id(snapshot_id)
.ok_or_else(|| {
Error::new(
ErrorKind::Unexpected,
format!(
"snapshot id {snapshot_id} not found in table {}",
table.identifier().name()
),
)
})?;
let table_schema = snapshot.schema(table.metadata())?;
let schema = Arc::new(schema_to_arrow_schema(&table_schema)?);
Ok(IcebergStaticTableProvider {
table,
snapshot_id: Some(snapshot_id),
schema,
})
}
}
#[async_trait]
impl TableProvider for IcebergStaticTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> ArrowSchemaRef {
self.schema.clone()
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(IcebergTableScan::new(
self.table.clone(),
self.snapshot_id,
self.schema.clone(),
projection,
filters,
limit,
)))
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> DFResult<Vec<TableProviderFilterPushDown>> {
Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
}
async fn insert_into(
&self,
_state: &dyn Session,
_input: Arc<dyn ExecutionPlan>,
_insert_op: InsertOp,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Err(to_datafusion_error(Error::new(
ErrorKind::FeatureUnsupported,
"Write operations are not supported on IcebergStaticTableProvider. \
Use IcebergTableProvider with a catalog for write support."
.to_string(),
)))
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::common::Column;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use iceberg::io::FileIO;
use iceberg::memory::{MEMORY_CATALOG_WAREHOUSE, MemoryCatalogBuilder};
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};
use iceberg::table::{StaticTable, Table};
use iceberg::{Catalog, CatalogBuilder, NamespaceIdent, TableCreation, TableIdent};
use tempfile::TempDir;
use super::*;
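// Loads a static test table from a checked-in metadata JSON file.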
async fn get_test_table_from_metadata_file() -> Table {
let metadata_file_name = "TableMetadataV2Valid.json";
let metadata_file_path = format!(
"{}/tests/test_data/{}",
env!("CARGO_MANIFEST_DIR"),
metadata_file_name
);
let file_io = FileIO::new_with_fs();
let static_identifier = TableIdent::from_strs(["static_ns", "static_table"]).unwrap();
let static_table =
StaticTable::from_metadata_file(&metadata_file_path, static_identifier, file_io)
.await
.unwrap();
static_table.into_table()
}
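// Creates an in-memory catalog holding a simple two-column test table.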
async fn get_test_catalog_and_table() -> (Arc<dyn Catalog>, NamespaceIdent, String, TempDir) {
let temp_dir = TempDir::new().unwrap();
let warehouse_path = temp_dir.path().to_str().unwrap().to_string();
let catalog = MemoryCatalogBuilder::default()
.load(
"memory",
HashMap::from([(MEMORY_CATALOG_WAREHOUSE.to_string(), warehouse_path.clone())]),
)
.await
.unwrap();
let namespace = NamespaceIdent::new("test_ns".to_string());
catalog
.create_namespace(&namespace, HashMap::new())
.await
.unwrap();
let schema = Schema::builder()
.with_schema_id(0)
.with_fields(vec![
NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
])
.build()
.unwrap();
let table_creation = TableCreation::builder()
.name("test_table".to_string())
.location(format!("{warehouse_path}/test_table"))
.schema(schema)
.properties(HashMap::new())
.build();
catalog
.create_table(&namespace, table_creation)
.await
.unwrap();
(
Arc::new(catalog),
namespace,
"test_table".to_string(),
temp_dir,
)
}
#[tokio::test]
async fn test_static_provider_from_table() {
let table = get_test_table_from_metadata_file().await;
let table_provider = IcebergStaticTableProvider::try_new_from_table(table.clone())
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("mytable", Arc::new(table_provider))
.unwrap();
let df = ctx.sql("SELECT * FROM mytable").await.unwrap();
let df_schema = df.schema();
let df_columns = df_schema.fields();
assert_eq!(df_columns.len(), 3);
let x_column = df_columns.first().unwrap();
let column_data = format!(
"{:?}:{:?}",
x_column.name(),
x_column.data_type().to_string()
);
assert_eq!(column_data, "\"x\":\"Int64\"");
let has_column = df_schema.has_column(&Column::from_name("z"));
assert!(has_column);
}
#[tokio::test]
async fn test_static_provider_from_snapshot() {
let table = get_test_table_from_metadata_file().await;
let snapshot_id = table.metadata().snapshots().next().unwrap().snapshot_id();
let table_provider =
IcebergStaticTableProvider::try_new_from_table_snapshot(table.clone(), snapshot_id)
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("mytable", Arc::new(table_provider))
.unwrap();
let df = ctx.sql("SELECT * FROM mytable").await.unwrap();
let df_schema = df.schema();
let df_columns = df_schema.fields();
assert_eq!(df_columns.len(), 3);
let x_column = df_columns.first().unwrap();
let column_data = format!(
"{:?}:{:?}",
x_column.name(),
x_column.data_type().to_string()
);
assert_eq!(column_data, "\"x\":\"Int64\"");
let has_column = df_schema.has_column(&Column::from_name("z"));
assert!(has_column);
}
#[tokio::test]
async fn test_static_provider_rejects_writes() {
let table = get_test_table_from_metadata_file().await;
let table_provider = IcebergStaticTableProvider::try_new_from_table(table.clone())
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("mytable", Arc::new(table_provider))
.unwrap();
let result = ctx.sql("INSERT INTO mytable VALUES (1, 2, 3)").await;
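// The rejection may surface when the plan is built or when it executes,
// so accept either.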
assert!(
result.is_err() || {
let df = result.unwrap();
df.collect().await.is_err()
}
);
}
#[tokio::test]
async fn test_static_provider_scan() {
let table = get_test_table_from_metadata_file().await;
let table_provider = IcebergStaticTableProvider::try_new_from_table(table.clone())
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("mytable", Arc::new(table_provider))
.unwrap();
let df = ctx.sql("SELECT count(*) FROM mytable").await.unwrap();
let physical_plan = df.create_physical_plan().await;
assert!(physical_plan.is_ok());
}
#[tokio::test]
async fn test_catalog_backed_provider_creation() {
let (catalog, namespace, table_name, _temp_dir) = get_test_catalog_and_table().await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let schema = provider.schema();
assert_eq!(schema.fields().len(), 2);
assert_eq!(schema.field(0).name(), "id");
assert_eq!(schema.field(1).name(), "name");
}
#[tokio::test]
async fn test_catalog_backed_provider_scan() {
let (catalog, namespace, table_name, _temp_dir) = get_test_catalog_and_table().await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("test_table", Arc::new(provider))
.unwrap();
let df = ctx.sql("SELECT * FROM test_table").await.unwrap();
let df_schema = df.schema();
assert_eq!(df_schema.fields().len(), 2);
assert_eq!(df_schema.field(0).name(), "id");
assert_eq!(df_schema.field(1).name(), "name");
let physical_plan = df.create_physical_plan().await;
assert!(physical_plan.is_ok());
}
#[tokio::test]
async fn test_catalog_backed_provider_insert() {
let (catalog, namespace, table_name, _temp_dir) = get_test_catalog_and_table().await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("test_table", Arc::new(provider))
.unwrap();
let result = ctx.sql("INSERT INTO test_table VALUES (1, 'test')").await;
assert!(result.is_ok());
let df = result.unwrap();
let execution_result = df.collect().await;
assert!(execution_result.is_ok());
}
#[tokio::test]
async fn test_physical_input_schema_consistent_with_logical_input_schema() {
let (catalog, namespace, table_name, _temp_dir) = get_test_catalog_and_table().await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let ctx = SessionContext::new();
ctx.register_table("test_table", Arc::new(provider))
.unwrap();
let df = ctx.sql("SELECT id, name FROM test_table").await.unwrap();
let logical_schema = df.schema().clone();
let physical_plan = df.create_physical_plan().await.unwrap();
let physical_schema = physical_plan.schema();
assert_eq!(
logical_schema.fields().len(),
physical_schema.fields().len()
);
for (logical_field, physical_field) in logical_schema
.fields()
.iter()
.zip(physical_schema.fields().iter())
{
assert_eq!(logical_field.name(), physical_field.name());
assert_eq!(logical_field.data_type(), physical_field.data_type());
}
}
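// Creates an in-memory catalog with a table partitioned by `category`,
// optionally setting the fanout-enabled table property.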
async fn get_partitioned_test_catalog_and_table(
fanout_enabled: Option<bool>,
) -> (Arc<dyn Catalog>, NamespaceIdent, String, TempDir) {
use iceberg::spec::{Transform, UnboundPartitionSpec};
let temp_dir = TempDir::new().unwrap();
let warehouse_path = temp_dir.path().to_str().unwrap().to_string();
let catalog = MemoryCatalogBuilder::default()
.load(
"memory",
HashMap::from([(MEMORY_CATALOG_WAREHOUSE.to_string(), warehouse_path.clone())]),
)
.await
.unwrap();
let namespace = NamespaceIdent::new("test_ns".to_string());
catalog
.create_namespace(&namespace, HashMap::new())
.await
.unwrap();
let schema = Schema::builder()
.with_schema_id(0)
.with_fields(vec![
NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
NestedField::required(2, "category", Type::Primitive(PrimitiveType::String)).into(),
])
.build()
.unwrap();
let partition_spec = UnboundPartitionSpec::builder()
.with_spec_id(0)
.add_partition_field(2, "category", Transform::Identity)
.unwrap()
.build();
let mut properties = HashMap::new();
if let Some(enabled) = fanout_enabled {
properties.insert(
iceberg::spec::TableProperties::PROPERTY_DATAFUSION_WRITE_FANOUT_ENABLED
.to_string(),
enabled.to_string(),
);
}
let table_creation = TableCreation::builder()
.name("partitioned_table".to_string())
.location(format!("{warehouse_path}/partitioned_table"))
.schema(schema)
.partition_spec(partition_spec)
.properties(properties)
.build();
catalog
.create_table(&namespace, table_creation)
.await
.unwrap();
(
Arc::new(catalog),
namespace,
"partitioned_table".to_string(),
temp_dir,
)
}
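// Recursively checks whether the physical plan tree contains a SortExec.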
fn plan_contains_sort(plan: &Arc<dyn ExecutionPlan>) -> bool {
if plan.name() == "SortExec" {
return true;
}
for child in plan.children() {
if plan_contains_sort(child) {
return true;
}
}
false
}
#[tokio::test]
async fn test_insert_plan_fanout_enabled_no_sort() {
use datafusion::datasource::TableProvider;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::physical_plan::empty::EmptyExec;
let (catalog, namespace, table_name, _temp_dir) =
get_partitioned_test_catalog_and_table(Some(true)).await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let ctx = SessionContext::new();
let input_schema = provider.schema();
let input = Arc::new(EmptyExec::new(input_schema)) as Arc<dyn ExecutionPlan>;
let state = ctx.state();
let insert_plan = provider
.insert_into(&state, input, InsertOp::Append)
.await
.unwrap();
assert!(
!plan_contains_sort(&insert_plan),
"Plan should NOT contain SortExec when fanout is enabled"
);
}
#[tokio::test]
async fn test_insert_plan_fanout_disabled_has_sort() {
use datafusion::datasource::TableProvider;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::physical_plan::empty::EmptyExec;
let (catalog, namespace, table_name, _temp_dir) =
get_partitioned_test_catalog_and_table(Some(false)).await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let ctx = SessionContext::new();
let input_schema = provider.schema();
let input = Arc::new(EmptyExec::new(input_schema)) as Arc<dyn ExecutionPlan>;
let state = ctx.state();
let insert_plan = provider
.insert_into(&state, input, InsertOp::Append)
.await
.unwrap();
assert!(
plan_contains_sort(&insert_plan),
"Plan should contain SortExec when fanout is disabled"
);
}
#[tokio::test]
async fn test_limit_pushdown_static_provider() {
use datafusion::datasource::TableProvider;
let table = get_test_table_from_metadata_file().await;
let table_provider = IcebergStaticTableProvider::try_new_from_table(table.clone())
.await
.unwrap();
let ctx = SessionContext::new();
let state = ctx.state();
let scan_plan = table_provider
.scan(&state, None, &[], Some(10))
.await
.unwrap();
let iceberg_scan = scan_plan
.as_any()
.downcast_ref::<IcebergTableScan>()
.expect("Expected IcebergTableScan");
assert_eq!(
iceberg_scan.limit(),
Some(10),
"Limit should be set to 10 in the scan plan"
);
}
#[tokio::test]
async fn test_limit_pushdown_catalog_backed_provider() {
use datafusion::datasource::TableProvider;
let (catalog, namespace, table_name, _temp_dir) = get_test_catalog_and_table().await;
let provider =
IcebergTableProvider::try_new(catalog.clone(), namespace.clone(), table_name.clone())
.await
.unwrap();
let ctx = SessionContext::new();
let state = ctx.state();
let scan_plan = provider.scan(&state, None, &[], Some(5)).await.unwrap();
let iceberg_scan = scan_plan
.as_any()
.downcast_ref::<IcebergTableScan>()
.expect("Expected IcebergTableScan");
assert_eq!(
iceberg_scan.limit(),
Some(5),
"Limit should be set to 5 in the scan plan"
);
}
#[tokio::test]
async fn test_no_limit_pushdown() {
use datafusion::datasource::TableProvider;
let table = get_test_table_from_metadata_file().await;
let table_provider = IcebergStaticTableProvider::try_new_from_table(table.clone())
.await
.unwrap();
let ctx = SessionContext::new();
let state = ctx.state();
let scan_plan = table_provider.scan(&state, None, &[], None).await.unwrap();
let iceberg_scan = scan_plan
.as_any()
.downcast_ref::<IcebergTableScan>()
.expect("Expected IcebergTableScan");
assert_eq!(
iceberg_scan.limit(),
None,
"Limit should be None when not specified"
);
}
}