use std::collections::HashMap;
use crate::expressions::Scalar;
use crate::schema::{DataType, StructType};
use crate::{DeltaResult, Error};
/// Validates caller-supplied partition values against the table's partition columns and schema.
///
/// Keys are first matched against `logical_partition_columns` ignoring case and rewritten to the
/// casing used there (`validate_keys`). The normalized values are then type-checked against
/// `logical_schema` (`validate_types`). On success, the normalized map is returned.
///
/// # Errors
///
/// Returns an invalid-partition-values error in any of these cases:
/// - a key is unknown;
/// - two keys collide after case normalization;
/// - a partition column has no value;
/// - a column is absent from the schema or has a non-primitive type;
/// - a non-null value's type differs from its column's type.
pub(crate) fn validate_partition_values(
    logical_partition_columns: &[String],
    logical_schema: &StructType,
    logical_partition_values: HashMap<String, Scalar>,
) -> DeltaResult<HashMap<String, Scalar>> {
    validate_keys(logical_partition_columns, logical_partition_values).and_then(|normalized| {
        // Type-check against the schema-cased keys, then hand the same map back.
        validate_types(logical_schema, &normalized).map(|()| normalized)
    })
}
/// Checks that the provided partition keys cover `logical_partition_columns` exactly,
/// comparing names case-insensitively.
///
/// Each accepted key is rewritten to the spelling used in `logical_partition_columns`, so the
/// returned map is keyed by the schema's column names (e.g. `"YEAR"` becomes `"Year"`).
///
/// # Errors
///
/// Returns an invalid-partition-values error when:
/// - a key does not match any partition column, ignoring case;
/// - two keys normalize to the same partition column (e.g. `"COL"` and `"col"`);
/// - a partition column has no provided value. The first such column in
///   `logical_partition_columns` order is reported.
fn validate_keys(
    logical_partition_columns: &[String],
    logical_partition_values: HashMap<String, Scalar>,
) -> DeltaResult<HashMap<String, Scalar>> {
    // Lowercased column name -> column name as spelled in `logical_partition_columns`.
    let schema_lookup: HashMap<String, &str> = logical_partition_columns
        .iter()
        .map(|name| (name.to_lowercase(), name.as_str()))
        .collect();

    let mut normalized = HashMap::with_capacity(logical_partition_values.len());
    for (key, value) in logical_partition_values {
        let schema_name = schema_lookup.get(&key.to_lowercase()).ok_or_else(|| {
            Error::invalid_partition_values(format!(
                "unknown partition column '{key}'. Expected one of: [{}]",
                logical_partition_columns.join(", ")
            ))
        })?;
        // Two distinct input keys may differ only in case; both would map to one column.
        if normalized.contains_key(*schema_name) {
            return Err(Error::invalid_partition_values(format!(
                "duplicate partition column '{key}' (normalized to same key as a previously provided entry)"
            )));
        }
        normalized.insert(schema_name.to_string(), value);
    }

    if let Some(missing) = logical_partition_columns
        .iter()
        .find(|col| !normalized.contains_key(col.as_str()))
    {
        // HashMap iteration order is unspecified; sort so the error text is deterministic.
        let mut provided: Vec<&str> = normalized.keys().map(String::as_str).collect();
        provided.sort_unstable();
        return Err(Error::invalid_partition_values(format!(
            "missing partition column '{missing}'. Provided: [{}]",
            provided.join(", ")
        )));
    }

    Ok(normalized)
}
fn validate_types(
logical_schema: &StructType,
logical_partition_values: &HashMap<String, Scalar>,
) -> DeltaResult<()> {
for (col_name, value) in logical_partition_values {
let field = logical_schema.field(col_name).ok_or_else(|| {
Error::invalid_partition_values(format!(
"partition column '{col_name}' not found in table schema"
))
})?;
let expected_type = field.data_type();
if matches!(
expected_type,
DataType::Struct(_) | DataType::Array(_) | DataType::Map(_)
) {
return Err(Error::invalid_partition_values(format!(
"partition column '{col_name}' has non-primitive type {expected_type:?}. \
Partition columns must be primitive types."
)));
}
if value.is_null() {
continue;
}
let actual_type = value.data_type();
if *expected_type != actual_type {
return Err(Error::invalid_partition_values(format!(
"partition column '{col_name}' has type {expected_type:?} but got \
value of type {actual_type:?}"
)));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::expressions::Scalar;
    use crate::schema::{ArrayType, DataType, MapType, StructField};

    /// Runs `validate_types` on a one-column schema `p: data_type` and expects success.
    fn assert_type_ok(data_type: DataType, value: Scalar) {
        let schema = StructType::new_unchecked(vec![StructField::not_null("p", data_type)]);
        let values = HashMap::from([("p".to_string(), value)]);
        validate_types(&schema, &values).unwrap();
    }

    /// Runs `validate_types` on a one-column schema `p: data_type`, expects failure, and returns
    /// the rendered error message.
    fn assert_type_err(data_type: DataType, value: Scalar) -> String {
        let schema = StructType::new_unchecked(vec![StructField::not_null("p", data_type)]);
        let values = HashMap::from([("p".to_string(), value)]);
        validate_types(&schema, &values).unwrap_err().to_string()
    }

    // ---- validate_keys ----

    #[test]
    fn test_validate_partition_keys_matching_keys_returns_ok() {
        let cols = vec!["year".to_string(), "region".to_string()];
        let values = HashMap::from([
            ("year".to_string(), Scalar::Integer(2024)),
            ("region".to_string(), Scalar::String("US".into())),
        ]);
        let result = validate_keys(&cols, values).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("year"), Some(&Scalar::Integer(2024)));
        assert_eq!(result.get("region"), Some(&Scalar::String("US".into())));
    }

    #[test]
    fn test_validate_partition_keys_empty_columns_and_values_returns_ok() {
        let cols: Vec<String> = vec![];
        let values = HashMap::new();
        let result = validate_keys(&cols, values).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_validate_partition_keys_missing_key_returns_error() {
        let cols = vec!["year".to_string(), "region".to_string()];
        let values = HashMap::from([("year".to_string(), Scalar::Integer(2024))]);
        let err = validate_keys(&cols, values).unwrap_err().to_string();
        assert!(err.contains("missing partition column 'region'"), "{err}");
    }

    #[test]
    fn test_validate_partition_keys_extra_key_returns_error() {
        let cols = vec!["year".to_string()];
        let values = HashMap::from([
            ("year".to_string(), Scalar::Integer(2024)),
            ("region".to_string(), Scalar::String("US".into())),
        ]);
        let err = validate_keys(&cols, values).unwrap_err().to_string();
        assert!(err.contains("unknown partition column 'region'"), "{err}");
    }

    #[test]
    fn test_validate_partition_keys_case_normalizes_to_schema_case() {
        let cols = vec!["Year".to_string()];
        let values = HashMap::from([("YEAR".to_string(), Scalar::Integer(2024))]);
        let result = validate_keys(&cols, values).unwrap();
        assert!(result.contains_key("Year"));
        assert!(!result.contains_key("YEAR"));
    }

    #[test]
    fn test_validate_partition_keys_duplicate_after_normalization_returns_error() {
        let cols = vec!["col".to_string()];
        let values = HashMap::from([
            ("COL".to_string(), Scalar::Integer(1)),
            ("col".to_string(), Scalar::Integer(2)),
        ]);
        let err = validate_keys(&cols, values).unwrap_err().to_string();
        assert!(err.contains("duplicate"), "{err}");
    }

    // ---- validate_types: accepted values ----

    #[rstest]
    #[case(DataType::INTEGER, Scalar::Integer(0))]
    #[case(DataType::INTEGER, Scalar::Integer(-1))]
    #[case(DataType::INTEGER, Scalar::Integer(i32::MAX))]
    #[case(DataType::INTEGER, Scalar::Integer(i32::MIN))]
    #[case(DataType::LONG, Scalar::Long(i64::MAX))]
    #[case(DataType::LONG, Scalar::Long(i64::MIN))]
    #[case(DataType::BYTE, Scalar::Byte(127))]
    #[case(DataType::BYTE, Scalar::Byte(-128))]
    #[case(DataType::SHORT, Scalar::Short(32767))]
    #[case(DataType::SHORT, Scalar::Short(-32768))]
    #[case(DataType::DOUBLE, Scalar::Double(0.0))]
    #[case(DataType::DOUBLE, Scalar::Double(-0.0))]
    #[case(DataType::DOUBLE, Scalar::Double(f64::MAX))]
    #[case(DataType::DOUBLE, Scalar::Double(f64::MIN_POSITIVE))]
    #[case(DataType::DOUBLE, Scalar::Double(f64::NAN))]
    #[case(DataType::DOUBLE, Scalar::Double(f64::INFINITY))]
    #[case(DataType::DOUBLE, Scalar::Double(f64::NEG_INFINITY))]
    #[case(DataType::FLOAT, Scalar::Float(0.0))]
    #[case(DataType::FLOAT, Scalar::Float(f32::NAN))]
    #[case(DataType::FLOAT, Scalar::Float(f32::INFINITY))]
    #[case(DataType::FLOAT, Scalar::Float(f32::NEG_INFINITY))]
    #[case(DataType::BOOLEAN, Scalar::Boolean(true))]
    #[case(DataType::BOOLEAN, Scalar::Boolean(false))]
    #[case(DataType::decimal(38, 18).unwrap(), Scalar::decimal(0, 38, 18).unwrap())]
    #[case(
        DataType::decimal(38, 18).unwrap(),
        Scalar::decimal(1_230_000_000_000_000_000i128, 38, 18).unwrap()
    )]
    #[case(
        DataType::decimal(38, 18).unwrap(),
        Scalar::decimal(-1_230_000_000_000_000_000i128, 38, 18).unwrap()
    )]
    #[case(DataType::DATE, Scalar::Date(19723))]
    #[case(DataType::DATE, Scalar::Date(0))]
    #[case(DataType::DATE, Scalar::Date(-719_162))]
    #[case(DataType::DATE, Scalar::Date(2_932_896))]
    #[case(DataType::TIMESTAMP, Scalar::Timestamp(1_718_451_045_000_000))]
    #[case(DataType::TIMESTAMP, Scalar::Timestamp(0))]
    #[case(DataType::TIMESTAMP, Scalar::Timestamp(1_718_499_599_999_999))]
    #[case(DataType::TIMESTAMP_NTZ, Scalar::TimestampNtz(1_718_451_045_000_000))]
    #[case(DataType::TIMESTAMP_NTZ, Scalar::TimestampNtz(0))]
    fn test_validate_types_encoding_table_rows_return_ok(
        #[case] data_type: DataType,
        #[case] value: Scalar,
    ) {
        assert_type_ok(data_type, value);
    }

    #[rstest]
    #[case(DataType::STRING, Scalar::String("\x00".into()))]
    #[case(DataType::STRING, Scalar::String("before\x00after".into()))]
    #[case(DataType::STRING, Scalar::String("a{b".into()))]
    #[case(DataType::STRING, Scalar::String("a}b".into()))]
    #[case(DataType::STRING, Scalar::String("hello world".into()))]
    #[case(DataType::STRING, Scalar::String("M\u{00FC}nchen".into()))]
    #[case(DataType::STRING, Scalar::String("\u{65E5}\u{672C}\u{8A9E}".into()))]
    #[case(DataType::STRING, Scalar::String("\u{1F3B5}\u{1F3B6}".into()))]
    #[case(DataType::STRING, Scalar::String("a<b>c|d".into()))]
    #[case(DataType::STRING, Scalar::String("a@b!c(d)".into()))]
    #[case(DataType::STRING, Scalar::String("a&b+c$d;e,f".into()))]
    #[case(DataType::STRING, Scalar::String("Serbia/srb%".into()))]
    #[case(DataType::STRING, Scalar::String("100%25".into()))]
    #[case(DataType::STRING, Scalar::String("".into()))]
    #[case(DataType::STRING, Scalar::String(" ".into()))]
    // Previously a verbatim duplicate of the single-space case; use a distinct multi-space value.
    #[case(DataType::STRING, Scalar::String("   ".into()))]
    #[case(DataType::BINARY, Scalar::Binary(vec![]))]
    #[case(DataType::BINARY, Scalar::Binary(vec![0xDE, 0xAD, 0xBE, 0xEF]))]
    #[case(DataType::BINARY, Scalar::Binary(vec![0x48, 0x45, 0x4C, 0x4C, 0x4F]))]
    #[case(DataType::BINARY, Scalar::Binary(vec![0x00, 0xFF]))]
    #[case(DataType::BINARY, Scalar::Binary(vec![0x2F, 0x3D, 0x25]))]
    fn test_validate_types_string_binary_table_rows_return_ok(
        #[case] data_type: DataType,
        #[case] value: Scalar,
    ) {
        assert_type_ok(data_type, value);
    }

    #[rstest]
    #[case(DataType::INTEGER, Scalar::Null(DataType::INTEGER))]
    #[case(DataType::LONG, Scalar::Null(DataType::LONG))]
    #[case(DataType::BYTE, Scalar::Null(DataType::BYTE))]
    #[case(DataType::SHORT, Scalar::Null(DataType::SHORT))]
    #[case(DataType::DOUBLE, Scalar::Null(DataType::DOUBLE))]
    #[case(DataType::FLOAT, Scalar::Null(DataType::FLOAT))]
    #[case(DataType::BOOLEAN, Scalar::Null(DataType::BOOLEAN))]
    #[case(
        DataType::decimal(38, 18).unwrap(),
        Scalar::Null(DataType::decimal(38, 18).unwrap())
    )]
    #[case(DataType::DATE, Scalar::Null(DataType::DATE))]
    #[case(DataType::TIMESTAMP, Scalar::Null(DataType::TIMESTAMP))]
    #[case(DataType::TIMESTAMP_NTZ, Scalar::Null(DataType::TIMESTAMP_NTZ))]
    #[case(DataType::STRING, Scalar::Null(DataType::STRING))]
    #[case(DataType::BINARY, Scalar::Null(DataType::BINARY))]
    fn test_validate_types_null_returns_ok(#[case] data_type: DataType, #[case] value: Scalar) {
        assert_type_ok(data_type, value);
    }

    #[test]
    fn test_validate_types_null_with_mismatched_inner_type_returns_ok() {
        assert_type_ok(DataType::INTEGER, Scalar::Null(DataType::STRING));
    }

    // ---- validate_types: rejected values ----

    #[rstest]
    #[case(DataType::STRING, Scalar::Integer(1))]
    #[case(DataType::INTEGER, Scalar::String("x".into()))]
    #[case(DataType::INTEGER, Scalar::Long(1))]
    #[case(DataType::LONG, Scalar::Integer(1))]
    #[case(DataType::DOUBLE, Scalar::Float(1.0))]
    #[case(DataType::FLOAT, Scalar::Double(1.0))]
    #[case(DataType::DATE, Scalar::Timestamp(0))]
    #[case(DataType::TIMESTAMP, Scalar::TimestampNtz(0))]
    #[case(DataType::STRING, Scalar::Binary(vec![0x41]))]
    #[case(DataType::BINARY, Scalar::String("A".into()))]
    #[case(DataType::BOOLEAN, Scalar::Integer(1))]
    fn test_validate_types_mismatch_returns_error(
        #[case] data_type: DataType,
        #[case] value: Scalar,
    ) {
        let err = assert_type_err(data_type, value);
        // Every error message contains the word "partition", so checking for "p" alone was
        // vacuous. Pin the column name and the type-mismatch wording instead.
        assert!(err.contains("partition column 'p' has type"), "{err}");
        assert!(err.contains("but got value of type"), "{err}");
    }

    #[rstest]
    #[case(
        DataType::Struct(Box::new(StructType::new_unchecked(vec![
            StructField::not_null("x", DataType::INTEGER),
        ]))),
        Scalar::Null(DataType::STRING)
    )]
    #[case(
        DataType::Array(Box::new(ArrayType::new(DataType::INTEGER, false))),
        Scalar::Null(DataType::STRING)
    )]
    #[case(
        DataType::Map(Box::new(MapType::new(DataType::STRING, DataType::INTEGER, false))),
        Scalar::Null(DataType::STRING)
    )]
    fn test_validate_types_complex_type_returns_error(
        #[case] data_type: DataType,
        #[case] value: Scalar,
    ) {
        // Null values are used to show that the nested-type check fires before the null check.
        let err = assert_type_err(data_type, value);
        assert!(err.contains("non-primitive type"), "{err}");
    }

    #[test]
    fn test_validate_types_column_not_in_schema_returns_error() {
        let schema =
            StructType::new_unchecked(vec![StructField::not_null("year", DataType::INTEGER)]);
        let values = HashMap::from([("nonexistent".to_string(), Scalar::Integer(42))]);
        let err = validate_types(&schema, &values).unwrap_err().to_string();
        assert!(err.contains("not found in table schema"), "{err}");
    }

    // ---- validate_partition_values: end to end ----

    #[test]
    fn test_validate_partition_values_with_case_mismatch_succeeds() {
        let partition_cols = vec!["Year".to_string(), "Region".to_string()];
        let schema = StructType::new_unchecked(vec![
            StructField::not_null("id", DataType::INTEGER),
            StructField::not_null("Year", DataType::INTEGER),
            StructField::nullable("Region", DataType::STRING),
        ]);
        let values = HashMap::from([
            ("YEAR".to_string(), Scalar::Integer(2024)),
            ("region".to_string(), Scalar::String("US".into())),
        ]);
        let result = validate_partition_values(&partition_cols, &schema, values).unwrap();
        assert_eq!(result.get("Year"), Some(&Scalar::Integer(2024)));
        assert_eq!(result.get("Region"), Some(&Scalar::String("US".into())));
    }
}