use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Schema};
use datafusion::common::ScalarValue;
use datafusion::error::{DataFusionError, Result as DFResult};
use lance_core::{Error, Result};
/// Column name `"_rowaddr"`.
///
/// NOTE(review): presumably the reserved column that carries row addresses;
/// nothing in this part of the file uses it, so confirm against callers.
pub const ROW_ADDRESS_COLUMN: &str = "_rowaddr";
/// Maps each primary-key column name to its positional index in `batch`.
///
/// The returned vector has one entry per element of `pk_columns`, in the same
/// order.
///
/// # Errors
///
/// Returns [`DataFusionError::Internal`] naming the first column in
/// `pk_columns` that is not present in the batch schema.
pub fn resolve_pk_indices(batch: &RecordBatch, pk_columns: &[String]) -> DFResult<Vec<usize>> {
    // Bind the schema handle once and reuse it for every lookup.
    let schema = batch.schema();
    let mut indices = Vec::with_capacity(pk_columns.len());
    for name in pk_columns {
        match schema.column_with_name(name) {
            Some((position, _field)) => indices.push(position),
            None => {
                return Err(DataFusionError::Internal(format!(
                    "Primary key column '{}' not found",
                    name
                )));
            }
        }
    }
    Ok(indices)
}
/// Reports whether `data_type` is accepted as a primary-key column type.
///
/// Accepted types are the signed and unsigned 8/16/32/64-bit integers,
/// `Boolean`, `Utf8`/`LargeUtf8`, and `Binary`/`LargeBinary`; every other
/// type yields `false`.
pub fn is_supported_pk_type(data_type: &DataType) -> bool {
    let is_fixed_width_integer = matches!(
        data_type,
        DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
    );
    let is_boolean = matches!(data_type, DataType::Boolean);
    let is_string = matches!(data_type, DataType::Utf8 | DataType::LargeUtf8);
    let is_bytes = matches!(data_type, DataType::Binary | DataType::LargeBinary);

    is_fixed_width_integer || is_boolean || is_string || is_bytes
}
/// Checks that every primary-key column exists in `schema` and has a type
/// accepted by [`is_supported_pk_type`].
///
/// Columns are checked in the order given; validation stops at the first
/// problem.
///
/// # Errors
///
/// Returns an invalid-input error when a column is missing from the schema,
/// or when its data type is not supported for hashing.
pub fn validate_pk_types(schema: &Schema, pk_columns: &[String]) -> Result<()> {
    for name in pk_columns {
        // The underlying lookup error is discarded in favour of a message
        // that names the offending primary-key column.
        let Ok(field) = schema.field_with_name(name) else {
            return Err(Error::invalid_input(format!(
                "Primary key column '{}' not found in schema",
                name
            )));
        };

        let data_type = field.data_type();
        if is_supported_pk_type(data_type) {
            continue;
        }

        return Err(Error::invalid_input(format!(
            "Primary key column '{}' has unsupported type {:?} for hashing; supported types: \
             Int8/16/32/64, UInt8/16/32/64, Boolean, Utf8/LargeUtf8, Binary/LargeBinary",
            name, data_type
        )));
    }
    Ok(())
}
/// Computes a 64-bit hash of the primary-key values found at `row_idx`.
///
/// For each index in `pk_indices`, in order, the column's null flag is fed to
/// the hasher, followed by the value itself when the slot is not null.
/// Integer, boolean, string and binary arrays are hashed through their native
/// value type. Any other array type is hashed through the `Debug` rendering of
/// its [`ScalarValue`]; if that conversion fails, the column contributes only
/// its null flag.
///
/// NOTE(review): the standard library does not guarantee that
/// `DefaultHasher`'s algorithm stays the same across Rust releases; confirm
/// these hashes are never persisted or compared across builds.
pub fn compute_pk_hash(batch: &RecordBatch, pk_indices: &[usize], row_idx: usize) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Probes the listed concrete array types in order. The first successful
    // downcast hashes the value at `$row` and the macro evaluates to `true`;
    // if none match it evaluates to `false`.
    macro_rules! hash_via_first_match {
        ($column:expr, $row:expr, $state:expr, [$($array_ty:ty),+ $(,)?]) => {{
            let any_ref = $column.as_any();
            $(
                if let Some(typed) = any_ref.downcast_ref::<$array_ty>() {
                    typed.value($row).hash($state);
                    true
                } else
            )+
            {
                false
            }
        }};
    }

    let mut state = DefaultHasher::new();
    for &pk_idx in pk_indices {
        let column = batch.column(pk_idx);

        // The null flag is always hashed, so a null slot still contributes
        // to the result.
        let null_here = column.is_null(row_idx);
        null_here.hash(&mut state);
        if null_here {
            continue;
        }

        let hashed_natively = hash_via_first_match!(
            column,
            row_idx,
            &mut state,
            [
                arrow_array::Int8Array,
                arrow_array::Int16Array,
                arrow_array::Int32Array,
                arrow_array::Int64Array,
                arrow_array::UInt8Array,
                arrow_array::UInt16Array,
                arrow_array::UInt32Array,
                arrow_array::UInt64Array,
                arrow_array::BooleanArray,
                arrow_array::StringArray,
                arrow_array::LargeStringArray,
                arrow_array::BinaryArray,
                arrow_array::LargeBinaryArray,
            ]
        );
        if hashed_natively {
            continue;
        }

        // Fallback for every other array type: hash the scalar's Debug text.
        if let Ok(scalar) = ScalarValue::try_from_array(column.as_ref(), row_idx) {
            format!("{:?}", scalar).hash(&mut state);
        }
    }
    state.finish()
}