//! Roundtrip write tests for partitioned Delta tables: one partition column
//! per primitive type, with normal and all-null partition values, across the
//! column mapping modes `none`, `name`, and `id`.

use std::collections::HashMap;
use std::sync::Arc;
use buoyant_kernel as delta_kernel;
use chrono::{NaiveDate, NaiveDateTime, TimeZone, Utc};
use delta_kernel::arrow::array::{
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array,
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch, StringArray,
TimestampMicrosecondArray,
};
use delta_kernel::arrow::datatypes::Schema as ArrowSchema;
use delta_kernel::committer::FileSystemCommitter;
use delta_kernel::engine::arrow_conversion::TryIntoArrow as _;
use delta_kernel::expressions::Scalar;
use delta_kernel::schema::{DataType, StructField, StructType};
use delta_kernel::table_features::ColumnMappingMode;
use delta_kernel::transaction::create_table::create_table;
use delta_kernel::transaction::data_layout::DataLayout;
use delta_kernel::Snapshot;
use rstest::rstest;
use test_utils::{read_scan, test_table_setup_mt, write_batch_to_table};
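
// Writes one row with a value for every supported partition type, then
// verifies the Hive-style file path, the `partitionValues` recorded in the
// commit, and the data read back (including after a checkpoint), for each
// column mapping mode.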
#[rstest]
#[case::cm_none(ColumnMappingMode::None)]
#[case::cm_name(ColumnMappingMode::Name)]
#[case::cm_id(ColumnMappingMode::Id)]
#[tokio::test(flavor = "multi_thread")]
async fn test_write_partitioned_normal_values_roundtrip(
#[case] cm_mode: ColumnMappingMode,
) -> Result<(), Box<dyn std::error::Error>> {
let (_tmp_dir, table_path, snapshot, engine) = setup_and_write(
all_types_schema(),
PARTITION_COLS,
cm_mode,
normal_arrow_columns(),
normal_partition_values()?,
)
.await?;
    assert_eq!(
        snapshot.table_configuration().partition_columns().len(),
        PARTITION_COLS.len()
    );
let (add, rel_path) = read_single_add(&table_path, 1)?;
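    // With column mapping off, files land under Hive-style `key=value`
    // partition directories; ':' and ' ' in timestamp values are
    // percent-encoded in the path.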
match cm_mode {
ColumnMappingMode::None => {
let expected_prefix = "\
p_string=hello/p_int=42/p_long=9876543210/p_short=7/\
p_byte=3/p_float=1.25/p_double=99.99/p_boolean=true/p_date=2025-03-31/\
p_timestamp=2025-03-31T15%3A30%3A00.123456Z/p_decimal=123.45/\
p_binary=Hello/p_timestamp_ntz=2025-03-31%2015%3A30%3A00.123456/";
assert!(
rel_path.starts_with(expected_prefix),
"CM off: relative path mismatch.\n \
expected: {expected_prefix}<uuid>.parquet\n got: {rel_path}"
);
assert!(rel_path.ends_with(".parquet"));
}
ColumnMappingMode::Name | ColumnMappingMode::Id => {
assert_cm_path(&rel_path);
}
}
let pv = add["partitionValues"].as_object().unwrap();
match cm_mode {
ColumnMappingMode::None => {
for (key, val) in EXPECTED_NORMAL_PVS {
assert_eq!(
pv.get(*key).and_then(|v| v.as_str()),
Some(*val),
"partitionValues[{key}] mismatch"
);
}
}
ColumnMappingMode::Name | ColumnMappingMode::Id => {
let logical_schema = snapshot.schema();
for (logical_key, expected_val) in EXPECTED_NORMAL_PVS {
let field = logical_schema.field(logical_key).unwrap();
let physical_key = field.physical_name(cm_mode);
assert_eq!(
pv.get(physical_key).and_then(|v| v.as_str()),
Some(*expected_val),
"partitionValues[{physical_key}] (logical: {logical_key}) mismatch"
);
}
}
}
verify_and_checkpoint(&snapshot, engine, assert_normal_values)?;
Ok(())
}
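
// Same shape as the test above, but every partition value is null: the file
// path uses `__HIVE_DEFAULT_PARTITION__` segments and every entry in
// `partitionValues` is JSON null, regardless of column mapping mode.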
#[rstest]
#[case::cm_none(ColumnMappingMode::None)]
#[case::cm_name(ColumnMappingMode::Name)]
#[case::cm_id(ColumnMappingMode::Id)]
#[tokio::test(flavor = "multi_thread")]
async fn test_write_partitioned_null_values_roundtrip(
#[case] cm_mode: ColumnMappingMode,
) -> Result<(), Box<dyn std::error::Error>> {
let (_tmp_dir, table_path, snapshot, engine) = setup_and_write(
all_types_schema(),
PARTITION_COLS,
cm_mode,
null_arrow_columns(),
null_partition_values()?,
)
.await?;
let (add, rel_path) = read_single_add(&table_path, 1)?;
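    // Null partition values all serialize to the same Hive sentinel segment.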
match cm_mode {
ColumnMappingMode::None => {
let expected_prefix = hive_prefix(PARTITION_COLS, "__HIVE_DEFAULT_PARTITION__");
assert!(
rel_path.starts_with(&expected_prefix),
"CM off null: relative path mismatch.\n \
expected: {expected_prefix}<uuid>.parquet\n got: {rel_path}"
);
}
ColumnMappingMode::Name | ColumnMappingMode::Id => {
assert_cm_path(&rel_path);
}
}
let pv = add["partitionValues"].as_object().unwrap();
assert_eq!(pv.len(), PARTITION_COLS.len());
for val in pv.values() {
assert!(
val.is_null(),
"all partition values should be null, got: {val}"
);
}
verify_and_checkpoint(&snapshot, engine, assert_all_partition_columns_null)?;
Ok(())
}
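
// Schema with one data column (`value`) plus one nullable partition column
// for each primitive type exercised by these tests.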
fn all_types_schema() -> Arc<StructType> {
Arc::new(
StructType::try_new(vec![
StructField::nullable("value", DataType::INTEGER),
StructField::nullable("p_string", DataType::STRING),
StructField::nullable("p_int", DataType::INTEGER),
StructField::nullable("p_long", DataType::LONG),
StructField::nullable("p_short", DataType::SHORT),
StructField::nullable("p_byte", DataType::BYTE),
StructField::nullable("p_float", DataType::FLOAT),
StructField::nullable("p_double", DataType::DOUBLE),
StructField::nullable("p_boolean", DataType::BOOLEAN),
StructField::nullable("p_date", DataType::DATE),
StructField::nullable("p_timestamp", DataType::TIMESTAMP),
StructField::nullable("p_decimal", DataType::decimal(10, 2).unwrap()),
StructField::nullable("p_binary", DataType::BINARY),
StructField::nullable("p_timestamp_ntz", DataType::TIMESTAMP_NTZ),
])
.unwrap(),
)
}
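
// Every column except `value` is a partition column.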
const PARTITION_COLS: &[&str] = &[
"p_string",
"p_int",
"p_long",
"p_short",
"p_byte",
"p_float",
"p_double",
"p_boolean",
"p_date",
"p_timestamp",
"p_decimal",
"p_binary",
"p_timestamp_ntz",
];
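
// One-row Arrow columns matching `all_types_schema`, in schema order.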
fn normal_arrow_columns() -> Vec<ArrayRef> {
let ts = ts_to_micros("2025-03-31 15:30:00.123456");
vec![
Arc::new(Int32Array::from(vec![1])),
Arc::new(StringArray::from(vec!["hello"])),
Arc::new(Int32Array::from(vec![42])),
Arc::new(Int64Array::from(vec![9_876_543_210i64])),
Arc::new(Int16Array::from(vec![7i16])),
Arc::new(Int8Array::from(vec![3i8])),
Arc::new(Float32Array::from(vec![1.25f32])),
Arc::new(Float64Array::from(vec![99.99f64])),
Arc::new(BooleanArray::from(vec![true])),
Arc::new(Date32Array::from(vec![date_to_days("2025-03-31")])),
ts_array(ts),
decimal_array(12345, 10, 2),
Arc::new(BinaryArray::from_vec(vec![b"Hello"])),
ts_ntz_array(ts),
]
}
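
// Kernel `Scalar` partition values matching `normal_arrow_columns`.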
fn normal_partition_values() -> Result<HashMap<String, Scalar>, Box<dyn std::error::Error>> {
let ts = ts_to_micros("2025-03-31 15:30:00.123456");
Ok(HashMap::from([
("p_string".into(), Scalar::String("hello".into())),
("p_int".into(), Scalar::Integer(42)),
("p_long".into(), Scalar::Long(9_876_543_210)),
("p_short".into(), Scalar::Short(7)),
("p_byte".into(), Scalar::Byte(3)),
("p_float".into(), Scalar::Float(1.25)),
("p_double".into(), Scalar::Double(99.99)),
("p_boolean".into(), Scalar::Boolean(true)),
("p_date".into(), Scalar::Date(date_to_days("2025-03-31"))),
("p_timestamp".into(), Scalar::Timestamp(ts)),
("p_decimal".into(), Scalar::decimal(12345, 10, 2)?),
("p_binary".into(), Scalar::Binary(b"Hello".to_vec())),
("p_timestamp_ntz".into(), Scalar::TimestampNtz(ts)),
]))
}
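
// Expected string serialization of each partition value in the commit's
// `partitionValues` map.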
const EXPECTED_NORMAL_PVS: &[(&str, &str)] = &[
("p_string", "hello"),
("p_int", "42"),
("p_long", "9876543210"),
("p_short", "7"),
("p_byte", "3"),
("p_float", "1.25"),
("p_double", "99.99"),
("p_boolean", "true"),
("p_date", "2025-03-31"),
("p_timestamp", "2025-03-31T15:30:00.123456Z"),
("p_decimal", "123.45"),
("p_binary", "Hello"),
("p_timestamp_ntz", "2025-03-31 15:30:00.123456"),
];
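
// One-row Arrow columns with a non-null `value` and nulls everywhere else.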
fn null_arrow_columns() -> Vec<ArrayRef> {
vec![
Arc::new(Int32Array::from(vec![1])),
Arc::new(StringArray::from(vec![None::<&str>])),
Arc::new(Int32Array::from(vec![None::<i32>])),
Arc::new(Int64Array::from(vec![None::<i64>])),
Arc::new(Int16Array::from(vec![None::<i16>])),
Arc::new(Int8Array::from(vec![None::<i8>])),
Arc::new(Float32Array::from(vec![None::<f32>])),
Arc::new(Float64Array::from(vec![None::<f64>])),
Arc::new(BooleanArray::from(vec![None::<bool>])),
Arc::new(Date32Array::from(vec![None::<i32>])),
Arc::new(TimestampMicrosecondArray::from(vec![None::<i64>]).with_timezone("UTC")),
Arc::new(
Decimal128Array::from(vec![None::<i128>])
.with_precision_and_scale(10, 2)
.unwrap(),
),
Arc::new(BinaryArray::from(vec![None::<&[u8]>])),
Arc::new(TimestampMicrosecondArray::from(vec![None::<i64>])),
]
}
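
// Typed null `Scalar`s for every partition column.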
fn null_partition_values() -> Result<HashMap<String, Scalar>, Box<dyn std::error::Error>> {
Ok(HashMap::from([
("p_string".into(), Scalar::Null(DataType::STRING)),
("p_int".into(), Scalar::Null(DataType::INTEGER)),
("p_long".into(), Scalar::Null(DataType::LONG)),
("p_short".into(), Scalar::Null(DataType::SHORT)),
("p_byte".into(), Scalar::Null(DataType::BYTE)),
("p_float".into(), Scalar::Null(DataType::FLOAT)),
("p_double".into(), Scalar::Null(DataType::DOUBLE)),
("p_boolean".into(), Scalar::Null(DataType::BOOLEAN)),
("p_date".into(), Scalar::Null(DataType::DATE)),
("p_timestamp".into(), Scalar::Null(DataType::TIMESTAMP)),
("p_decimal".into(), Scalar::Null(DataType::decimal(10, 2)?)),
("p_binary".into(), Scalar::Null(DataType::BINARY)),
(
"p_timestamp_ntz".into(),
Scalar::Null(DataType::TIMESTAMP_NTZ),
),
]))
}
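
// Asserts that row 0 of column `$idx`, downcast to `$arr_ty`, equals `$expected`.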
macro_rules! assert_col {
($batch:expr, $idx:expr, $arr_ty:ty, $expected:expr) => {
assert_eq!(
$batch
.column($idx)
.as_any()
.downcast_ref::<$arr_ty>()
.unwrap()
.value(0),
$expected,
"column {} ({}) value mismatch",
$idx,
$batch.schema().field($idx).name()
);
};
}
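
// Checks every column of the single read-back row against the values written
// by `normal_arrow_columns`.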
fn assert_normal_values(sorted: &RecordBatch) {
let ts = ts_to_micros("2025-03-31 15:30:00.123456");
assert_eq!(sorted.num_rows(), 1);
    assert_col!(sorted, 0, Int32Array, 1);
    assert_col!(sorted, 1, StringArray, "hello");
    assert_col!(sorted, 2, Int32Array, 42);
    assert_col!(sorted, 3, Int64Array, 9_876_543_210i64);
    assert_col!(sorted, 4, Int16Array, 7i16);
    assert_col!(sorted, 5, Int8Array, 3i8);
    assert_col!(sorted, 6, Float32Array, 1.25f32);
    assert_col!(sorted, 7, Float64Array, 99.99f64);
    assert_col!(sorted, 8, BooleanArray, true);
    assert_col!(sorted, 9, Date32Array, date_to_days("2025-03-31"));
    assert_col!(sorted, 10, TimestampMicrosecondArray, ts);
    assert_col!(sorted, 11, Decimal128Array, 12345);
    assert_eq!(
        sorted
            .column(12)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap()
            .value(0),
        b"Hello"
    );
    assert_col!(sorted, 13, TimestampMicrosecondArray, ts);
}
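
// Column 0 is the `value` data column; the remaining columns are the
// partition columns, which should all read back as null.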
fn assert_all_partition_columns_null(sorted: &RecordBatch) {
assert_eq!(sorted.num_rows(), 1);
    for col_idx in 1..=PARTITION_COLS.len() {
assert!(
sorted.column(col_idx).is_null(0),
"partition column at index {col_idx} ({}) should be null",
sorted.schema().field(col_idx).name()
);
}
}
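
// With column mapping on, paths are `<two-char prefix>/<uuid>.parquet` rather
// than Hive-style partition directories.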
fn assert_cm_path(rel_path: &str) {
let segments: Vec<&str> = rel_path.split('/').collect();
assert_eq!(
segments.len(),
2,
"CM on: path should be <prefix>/<file>, got: {rel_path}"
);
assert_eq!(segments[0].len(), 2, "prefix should be 2 chars");
assert!(segments[0].chars().all(|c| c.is_ascii_alphanumeric()));
assert!(segments[1].ends_with(".parquet"));
}
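
// String value expected by the `delta.columnMapping.mode` table property.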
fn cm_mode_str(mode: ColumnMappingMode) -> &'static str {
match mode {
ColumnMappingMode::None => "none",
ColumnMappingMode::Id => "id",
ColumnMappingMode::Name => "name",
}
}
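
// Creates an empty partitioned table, optionally enabling column mapping, and
// returns a snapshot of it.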
fn create_partitioned_table(
table_path: &str,
engine: &dyn delta_kernel::Engine,
schema: Arc<StructType>,
partition_cols: &[&str],
cm_mode: ColumnMappingMode,
) -> Result<Arc<Snapshot>, Box<dyn std::error::Error>> {
let mut builder = create_table(table_path, schema, "test/1.0")
.with_data_layout(DataLayout::partitioned(partition_cols));
if cm_mode != ColumnMappingMode::None {
builder =
builder.with_table_properties([("delta.columnMapping.mode", cm_mode_str(cm_mode))]);
}
let _ = builder
.build(engine, Box::new(FileSystemCommitter::new()))?
.commit(engine)?;
Ok(Snapshot::builder_for(table_path).build(engine)?)
}
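
// Scans the table into a single batch sorted by the first column, so row
// order is deterministic for assertions.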
fn read_sorted(
snapshot: &Arc<Snapshot>,
engine: Arc<dyn delta_kernel::Engine>,
) -> Result<RecordBatch, Box<dyn std::error::Error>> {
let scan = snapshot.clone().scan_builder().build()?;
let batches = read_scan(&scan, engine)?;
assert!(!batches.is_empty(), "expected at least one batch");
let merged = delta_kernel::arrow::compute::concat_batches(&batches[0].schema(), &batches)?;
let sort_indices = delta_kernel::arrow::compute::sort_to_indices(merged.column(0), None, None)?;
let sorted_columns: Vec<ArrayRef> = merged
.columns()
.iter()
.map(|col| delta_kernel::arrow::compute::take(col.as_ref(), &sort_indices, None).unwrap())
.collect();
Ok(RecordBatch::try_new(merged.schema(), sorted_columns)?)
}
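
// Parses a `YYYY-MM-DD HH:MM:SS.ffffff` string into microseconds since the
// Unix epoch, interpreting the value as UTC.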
fn ts_to_micros(s: &str) -> i64 {
let ndt = NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").unwrap();
Utc.from_utc_datetime(&ndt)
.signed_duration_since(chrono::DateTime::UNIX_EPOCH)
.num_microseconds()
.unwrap()
}
fn date_to_days(s: &str) -> i32 {
let date = NaiveDate::parse_from_str(s, "%Y-%m-%d").unwrap();
let dt = Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap());
dt.signed_duration_since(chrono::DateTime::UNIX_EPOCH)
.num_days() as i32
}
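
// Single-element Arrow array constructors for the timestamp, timestamp_ntz,
// and decimal columns.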
fn ts_array(micros: i64) -> ArrayRef {
Arc::new(TimestampMicrosecondArray::from(vec![micros]).with_timezone("UTC"))
}
fn ts_ntz_array(micros: i64) -> ArrayRef {
Arc::new(TimestampMicrosecondArray::from(vec![micros]))
}
fn decimal_array(value: i128, precision: u8, scale: i8) -> ArrayRef {
Arc::new(
Decimal128Array::from(vec![value])
.with_precision_and_scale(precision, scale)
.unwrap(),
)
}
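
// Creates a partitioned table, writes one batch with the given partition
// values, and returns the table path, snapshot, and engine for assertions.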
async fn setup_and_write(
schema: Arc<StructType>,
partition_cols: &[&str],
cm_mode: ColumnMappingMode,
arrow_columns: Vec<ArrayRef>,
partition_values: HashMap<String, Scalar>,
) -> Result<
(
tempfile::TempDir,
String,
Arc<Snapshot>,
Arc<dyn delta_kernel::Engine>,
),
Box<dyn std::error::Error>,
> {
let (tmp_dir, table_path, engine) = test_table_setup_mt()?;
let arrow_schema: Arc<ArrowSchema> = Arc::new(schema.as_ref().try_into_arrow()?);
let snapshot = create_partitioned_table(
&table_path,
engine.as_ref(),
schema,
partition_cols,
cm_mode,
)?;
let batch = RecordBatch::try_new(arrow_schema, arrow_columns)?;
let snapshot =
write_batch_to_table(&snapshot, engine.as_ref(), batch, partition_values).await?;
Ok((
tmp_dir,
table_path,
snapshot,
engine as Arc<dyn delta_kernel::Engine>,
))
}
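
// Asserts on the read-back data, writes a checkpoint, reloads the snapshot,
// and asserts again to confirm the checkpoint roundtrips.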
fn verify_and_checkpoint(
snapshot: &Arc<Snapshot>,
engine: Arc<dyn delta_kernel::Engine>,
assert_fn: fn(&RecordBatch),
) -> Result<(), Box<dyn std::error::Error>> {
let sorted = read_sorted(snapshot, engine.clone())?;
assert_fn(&sorted);
snapshot.checkpoint(engine.as_ref())?;
let reloaded = Snapshot::builder_for(snapshot.table_root()).build(engine.as_ref())?;
let sorted = read_sorted(&reloaded, engine)?;
assert_fn(&sorted);
Ok(())
}
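
// Builds the expected Hive path prefix `col1=value/col2=value/.../`.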
fn hive_prefix(cols: &[&str], value: &str) -> String {
cols.iter()
.map(|c| format!("{c}={value}"))
.collect::<Vec<_>>()
.join("/")
+ "/"
}
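
// Parses the newline-delimited commit JSON at the given version and returns
// each `add` action.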
fn read_add_actions_json(
table_path: &str,
version: u64,
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
let commit_path = format!("{table_path}/_delta_log/{version:020}.json");
let content = std::fs::read_to_string(commit_path)?;
let parsed: Vec<serde_json::Value> = serde_json::Deserializer::from_str(&content)
.into_iter::<serde_json::Value>()
.collect::<Result<Vec<_>, _>>()?;
Ok(parsed
.into_iter()
.filter_map(|v| v.get("add").cloned())
.collect())
}
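
// Expects exactly one `add` action in the commit and returns it along with
// its (relative) file path.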
fn read_single_add(
table_path: &str,
version: u64,
) -> Result<(serde_json::Value, String), Box<dyn std::error::Error>> {
let adds = read_add_actions_json(table_path, version)?;
assert_eq!(adds.len(), 1, "should have exactly one add action");
let add = adds.into_iter().next().unwrap();
let rel_path = add["path"].as_str().unwrap().to_string();
assert!(
!rel_path.contains("://"),
"should produce relative paths, got: {rel_path}"
);
Ok((add, rel_path))
}