use std::{collections::HashSet, sync::Arc};
use super::fragment::FileFragment;
use super::{
Dataset,
transaction::{Operation, Transaction},
};
use crate::{Error, Result, io::exec::Planner};
use arrow::compute::CastOptions;
use arrow::compute::can_cast_types;
use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use datafusion::execution::SendableRecordBatchStream;
use futures::stream::{StreamExt, TryStreamExt};
use lance_arrow::SchemaExt;
use lance_core::datatypes::{Field, Schema};
use lance_datafusion::utils::StreamingWriteSource;
use lance_encoding::constants::{PACKED_STRUCT_LEGACY_META_KEY, PACKED_STRUCT_META_KEY};
use lance_encoding::version::LanceFileVersion;
use lance_table::format::Fragment;
mod optimize;
use optimize::{
ChainedNewColumnTransformOptimizer, NewColumnTransformOptimizer, SqlToAllNullsOptimizer,
};
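/// Scan the column at `path` and return an error if it contains any NULLs.
///
/// Used before making a column non-nullable. Columns that are already
/// non-nullable are skipped, since they cannot contain NULLs.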
async fn validate_no_nulls_before_making_non_nullable(dataset: &Dataset, path: &str) -> Result<()> {
let field = dataset.schema().field(path).ok_or_else(|| {
Error::invalid_input(format!("Column \"{}\" does not exist in the dataset", path))
})?;
if !field.nullable {
return Ok(());
}
let mut scanner = dataset.scan();
scanner.project(&[path])?;
let mut stream = scanner.try_into_stream().await?;
while let Some(batch) = stream.try_next().await? {
if batch.num_columns() != 1 {
return Err(Error::internal(format!(
"Expected exactly one column in validation scan for {}, got {}",
path,
batch.num_columns()
)));
}
let col = batch.column(0);
if col.null_count() > 0 {
return Err(Error::invalid_input(format!(
"Column \"{}\" contains NULL values and cannot be made non-nullable",
path
)));
}
}
Ok(())
}
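/// Identifies one batch within one fragment; used as the key when
/// checkpointing [`BatchUDF`] results.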
#[derive(Debug, Clone, PartialEq)]
pub struct BatchInfo {
pub fragment_id: u32,
pub batch_index: usize,
}
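/// Persists intermediate results of an expensive [`BatchUDF`] so that an
/// interrupted `add_columns` call can resume without recomputing batches, or
/// whole fragments, that were already mapped.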
pub trait UDFCheckpointStore: Send + Sync {
fn get_batch(&self, info: &BatchInfo) -> Result<Option<RecordBatch>>;
fn insert_batch(&self, info: BatchInfo, batch: RecordBatch) -> Result<()>;
fn get_fragment(&self, fragment_id: u32) -> Result<Option<Fragment>>;
fn insert_fragment(&self, fragment: Fragment) -> Result<()>;
}
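/// A user-provided function that maps each batch of the read columns to a
/// batch of new column values, with optional checkpointing of results.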
pub struct BatchUDF {
#[allow(clippy::type_complexity)]
pub mapper: Box<dyn Fn(&RecordBatch) -> Result<RecordBatch> + Send + Sync>,
pub output_schema: Arc<ArrowSchema>,
pub result_checkpoint: Option<Arc<dyn UDFCheckpointStore>>,
}
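/// Describes how values for newly added columns are produced.
///
/// * `BatchUDF`: compute the new columns from existing ones with a Rust closure.
/// * `SqlExpressions`: `(name, expression)` pairs evaluated over existing columns.
/// * `Stream` / `Reader`: supply precomputed values; the input must yield
///   exactly one value per existing row, in order.
/// * `AllNulls`: add nullable columns without writing any data files
///   (non-legacy storage only).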
pub enum NewColumnTransform {
BatchUDF(BatchUDF),
SqlExpressions(Vec<(String, String)>),
Stream(SendableRecordBatchStream),
Reader(Box<dyn RecordBatchReader + Send>),
AllNulls(Arc<ArrowSchema>),
}
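/// A requested change to one existing column: rename it, change its
/// nullability, cast it to a new type, or any combination of the three.
///
/// Alterations are built fluently. For example (mirroring the tests below):
///
/// ```ignore
/// let alteration = ColumnAlteration::new("a".into())
///     .rename("x".into())
///     .set_nullable(true);
/// ```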
pub struct ColumnAlteration {
pub path: String,
pub rename: Option<String>,
pub nullable: Option<bool>,
pub data_type: Option<DataType>,
}
impl ColumnAlteration {
pub fn new(path: String) -> Self {
Self {
path,
rename: None,
nullable: None,
data_type: None,
}
}
pub fn rename(mut self, name: String) -> Self {
self.rename = Some(name);
self
}
pub fn set_nullable(mut self, nullable: bool) -> Self {
self.nullable = Some(nullable);
self
}
pub fn cast_to(mut self, data_type: DataType) -> Self {
self.data_type = Some(data_type);
self
}
}
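/// Returns true only when a cast from `from_type` to `to_type` stays within
/// the same type family: integer to integer, float to float, temporal to
/// temporal, decimal to decimal, string or binary size variants, and lists
/// whose item types satisfy the same rule. Cross-family casts return false
/// even when Arrow's `can_cast_types` would allow them.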
fn is_upcast_downcast(from_type: &DataType, to_type: &DataType) -> bool {
use DataType::*;
match from_type {
from_type if from_type.is_integer() => to_type.is_integer(),
from_type if from_type.is_floating() => to_type.is_floating(),
from_type if from_type.is_temporal() => to_type.is_temporal(),
Boolean => matches!(to_type, Boolean),
Utf8 | LargeUtf8 => matches!(to_type, Utf8 | LargeUtf8),
Binary | LargeBinary => matches!(to_type, Binary | LargeBinary),
Decimal128(_, _) | Decimal256(_, _) => {
matches!(to_type, Decimal128(_, _) | Decimal256(_, _))
}
List(from_field) | LargeList(from_field) | FixedSizeList(from_field, _) => match to_type {
List(to_field) | LargeList(to_field) | FixedSizeList(to_field, _) => {
is_upcast_downcast(from_field.data_type(), to_field.data_type())
}
_ => false,
},
_ => false,
}
}
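/// Detects packed structs from field metadata: the legacy key must be set to
/// `"true"`, while mere presence of the current key is enough.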
trait ArrowFieldExt {
fn is_packed(&self) -> bool;
}
impl ArrowFieldExt for ArrowField {
fn is_packed(&self) -> bool {
let metadata = self.metadata();
metadata
.get(PACKED_STRUCT_LEGACY_META_KEY)
.map(|v| v == "true")
.unwrap_or(metadata.contains_key(PACKED_STRUCT_META_KEY))
}
}
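/// Returns an error if the new field `right` conflicts with the existing
/// field `left`. Fields with different names never conflict. Structs
/// (including struct items inside list types) are checked recursively, so
/// new sub-columns may be added when the file version supports it, while
/// existing leaves, packed structs, and mismatched types are rejected.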
fn check_field_conflict(
left: &ArrowField,
right: &ArrowField,
version: &LanceFileVersion,
) -> Result<()> {
if left.name() != right.name() {
return Ok(());
}
match (left.data_type(), right.data_type()) {
(DataType::Struct(fl), DataType::Struct(fr)) => {
if !version.support_add_sub_column() {
return Err(Error::invalid_input(format!(
"Column {} is a struct col, add sub column is not supported in Lance file version {}",
left.name(),
version
)));
}
if left.is_packed() || right.is_packed() {
return Err(Error::invalid_input(format!(
"Column {} is packed struct and already exists in the dataset",
left.name()
)));
}
for l_field in fl.iter() {
if let Some((_, r_field)) = fr.find(l_field.name()) {
check_field_conflict(l_field, r_field, version)?;
}
}
Ok(())
}
(DataType::List(fl), DataType::List(fr)) => check_field_conflict(fl, fr, version),
(DataType::LargeList(fl), DataType::LargeList(fr)) => check_field_conflict(fl, fr, version),
(DataType::FixedSizeList(fl, _), DataType::FixedSizeList(fr, _)) => {
check_field_conflict(fl, fr, version)
}
(l_type, r_type) if l_type == r_type => Err(Error::invalid_input(format!(
"Column {} already exists in the dataset",
left.name()
))),
(_, _) => Err(Error::invalid_input(format!(
"Type conflicts between {}({}) and {}({})",
left.name(),
left.data_type(),
right.name(),
right.data_type()
))),
}
}
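/// Applies `transforms` to `fragments` and returns the rewritten fragment
/// metadata together with the merged schema. Nothing is committed here;
/// [`add_columns`] wraps the result in a `Merge` transaction.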
pub(super) async fn add_columns_to_fragments(
dataset: &Dataset,
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
fragments: &[FileFragment],
batch_size: Option<u32>,
) -> Result<(Vec<Fragment>, Schema)> {
let version = dataset.manifest.data_storage_format.lance_file_version()?;
let check_names = |output_schema: &ArrowSchema| {
for field in &dataset.schema().fields {
if let Ok(out_field) = output_schema.field_with_name(&field.name) {
let ds_field = ArrowField::from(field);
check_field_conflict(&ds_field, out_field, &version)?;
}
}
Ok::<(), Error>(())
};
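// Rewrite transforms that can be answered from metadata alone. On
// non-legacy datasets a SQL expression like `CAST(NULL AS int)` is turned
// into an `AllNulls` transform so no data files need to be written.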
let mut optimizer = ChainedNewColumnTransformOptimizer::new(vec![]);
if !dataset.is_legacy_storage() {
optimizer.add_optimizer(Box::new(SqlToAllNullsOptimizer::new()));
}
let transforms = optimizer.optimize(dataset, transforms)?;
let (output_schema, fragments) = match transforms {
NewColumnTransform::BatchUDF(udf) => {
check_names(udf.output_schema.as_ref())?;
let fragments = add_columns_impl(
fragments,
read_columns,
udf.mapper,
batch_size,
udf.result_checkpoint,
None,
)
.await?;
Result::Ok((udf.output_schema, fragments))
}
NewColumnTransform::SqlExpressions(expressions) => {
let arrow_schema = Arc::new(ArrowSchema::from(dataset.schema()));
let planner = Planner::new(arrow_schema);
let exprs = expressions
.into_iter()
.map(|(name, expr)| {
let expr = planner.parse_expr(&expr)?;
let expr = planner.optimize_expr(expr)?;
Ok((name, expr))
})
.collect::<Result<Vec<_>>>()?;
let needed_columns = exprs
.iter()
.flat_map(|(_, expr)| Planner::column_names_in_expr(expr))
.collect::<HashSet<_>>()
.into_iter()
.collect::<Vec<_>>();
let read_schema = dataset.schema().project(&needed_columns)?;
let read_schema = Arc::new(ArrowSchema::from(&read_schema));
let planner = Planner::new(read_schema.clone());
let exprs = exprs
.into_iter()
.map(|(name, expr)| {
let expr = planner.create_physical_expr(&expr)?;
Ok((name, expr))
})
.collect::<Result<Vec<_>>>()?;
let output_schema = Arc::new(ArrowSchema::new(
exprs
.iter()
.map(|(name, expr)| {
Ok(ArrowField::new(
name,
expr.data_type(read_schema.as_ref())?,
expr.nullable(read_schema.as_ref())?,
))
})
.collect::<Result<Vec<_>>>()?,
));
check_names(output_schema.as_ref())?;
let schema_ref = output_schema.clone();
let mapper = move |batch: &RecordBatch| {
let num_rows = batch.num_rows();
let columns = exprs
.iter()
.map(|(_, expr)| Ok(expr.evaluate(batch)?.into_array(num_rows)?))
.collect::<Result<Vec<_>>>()?;
let batch = RecordBatch::try_new(schema_ref.clone(), columns)?;
Ok(batch)
};
let mapper = Box::new(mapper);
let read_columns = Some(read_schema.field_names().into_iter().cloned().collect());
let fragments =
add_columns_impl(fragments, read_columns, mapper, batch_size, None, None).await?;
Ok((output_schema, fragments))
}
NewColumnTransform::Stream(stream) => {
let output_schema = stream.schema();
check_names(output_schema.as_ref())?;
let fragments = add_columns_from_stream(fragments, stream, None, batch_size).await?;
Ok((output_schema, fragments))
}
NewColumnTransform::Reader(reader) => {
let output_schema = reader.schema();
check_names(output_schema.as_ref())?;
let stream = reader.into_stream();
let fragments = add_columns_from_stream(fragments, stream, None, batch_size).await?;
Ok((output_schema, fragments))
}
NewColumnTransform::AllNulls(output_schema) => {
check_names(output_schema.as_ref())?;
let schema = Schema::try_from(output_schema.as_ref())?;
if !schema.all_fields_nullable() {
return Err(Error::invalid_input_source(
"All-null columns must be nullable.".into(),
));
}
let fragments = fragments
.iter()
.map(|f| f.metadata.clone())
.collect::<Vec<_>>();
if dataset.is_legacy_storage() {
return Err(Error::not_supported_source(
"Cannot add all-null columns to legacy dataset version.".into(),
));
}
Ok((output_schema, fragments))
}
}?;
let mut schema = dataset.schema().merge(output_schema.as_ref())?;
schema.set_field_id(Some(dataset.manifest.max_field_id()));
Ok((fragments, schema))
}
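/// Adds new columns to every fragment of the dataset and commits the result
/// as a `Merge` operation.
///
/// For example, deriving a column from `id` with a SQL expression (as
/// exercised in the tests below):
///
/// ```ignore
/// dataset
///     .add_columns(
///         NewColumnTransform::SqlExpressions(vec![("double_id".into(), "2 * id".into())]),
///         None,
///         None,
///     )
///     .await?;
/// ```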
pub(super) async fn add_columns(
dataset: &mut Dataset,
transforms: NewColumnTransform,
read_columns: Option<Vec<String>>,
batch_size: Option<u32>,
) -> Result<()> {
let (fragments, schema) = add_columns_to_fragments(
dataset,
transforms,
read_columns,
&dataset.get_fragments(),
batch_size,
)
.await?;
let operation = Operation::Merge { fragments, schema };
let transaction = Transaction::new(dataset.manifest.version, operation, None);
dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
Ok(())
}
#[allow(clippy::type_complexity)]
async fn add_columns_impl(
fragments: &[FileFragment],
read_columns: Option<Vec<String>>,
mapper: Box<dyn Fn(&RecordBatch) -> Result<RecordBatch> + Send + Sync>,
batch_size: Option<u32>,
result_cache: Option<Arc<dyn UDFCheckpointStore>>,
schemas: Option<(Schema, Schema)>,
) -> Result<Vec<Fragment>> {
let read_columns_ref = read_columns.as_deref();
let mapper_ref = mapper.as_ref();
let fragments = futures::stream::iter(fragments)
.then(|fragment| {
let cache_ref = result_cache.clone();
let schemas_ref = &schemas;
async move {
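// Reuse a fully checkpointed fragment if the store has one, skipping
// recomputation entirely.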
if let Some(cache) = &cache_ref {
let fragment_id = fragment.id() as u32;
let fragment = cache.get_fragment(fragment_id)?;
if let Some(fragment) = fragment {
return Ok(fragment);
}
}
let mut updater = fragment
.updater(read_columns_ref, schemas_ref.clone(), batch_size)
.await?;
let mut batch_index = 0;
while let Some(batch) = updater.next().await? {
let batch_info = BatchInfo {
fragment_id: fragment.id() as u32,
batch_index,
};
let new_batch = if let Some(cache) = &cache_ref {
if let Some(batch) = cache.get_batch(&batch_info)? {
batch
} else {
let new_batch = mapper_ref(batch)?;
cache.insert_batch(batch_info, new_batch.clone())?;
new_batch
}
} else {
mapper_ref(batch)?
};
updater.update(new_batch).await?;
batch_index += 1;
}
let fragment = updater.finish().await?;
if let Some(cache) = &cache_ref {
cache.insert_fragment(fragment.clone())?;
}
Ok::<_, Error>(fragment)
}
})
.try_collect::<Vec<_>>()
.await?;
Ok(fragments)
}
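/// Consumes `stream` in row order, slicing and concatenating its batches so
/// each updater batch receives exactly as many rows as the fragment already
/// has. Errors if the stream ends early or produces more rows than the
/// dataset contains.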
async fn add_columns_from_stream(
fragments: &[FileFragment],
mut stream: SendableRecordBatchStream,
schemas: Option<(Schema, Schema)>,
batch_size: Option<u32>,
) -> Result<Vec<Fragment>> {
let mut new_fragments = Vec::with_capacity(fragments.len());
let mut last_seen_batch: Option<RecordBatch> = None;
for fragment in fragments {
let mut updater = fragment
.updater::<String>(Some(&[]), schemas.clone(), batch_size)
.await?;
while let Some(batch) = updater.next().await? {
debug_assert_eq!(batch.num_columns(), 1);
let mut rows_remaining = batch.num_rows();
let mut batches = Vec::new();
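// Cover this batch's rows exactly, starting with any leftover slice
// carried in `last_seen_batch` and then pulling from the stream.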
while rows_remaining > 0 {
let next_batch = if let Some(last_seen_batch) = last_seen_batch {
last_seen_batch
} else {
stream.next().await.ok_or_else(|| {
Error::invalid_input(
"Stream ended before producing values for all rows in dataset",
)
})??
};
let num_rows = next_batch.num_rows();
if num_rows > rows_remaining {
let new_batch = next_batch.slice(0, rows_remaining);
batches.push(new_batch);
last_seen_batch =
Some(next_batch.slice(rows_remaining, num_rows - rows_remaining));
rows_remaining = 0;
} else {
batches.push(next_batch);
rows_remaining -= num_rows;
last_seen_batch = None;
}
}
let new_batch =
arrow_select::concat::concat_batches(&batches[0].schema(), batches.iter())?;
updater.update(new_batch).await?;
}
new_fragments.push(updater.finish().await?);
}
if last_seen_batch.is_some() || stream.next().await.is_some() {
return Err(Error::invalid_input_source(
"Stream produced more values than expected for dataset".into(),
));
}
Ok(new_fragments)
}
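/// Applies the requested [`ColumnAlteration`]s. Making a column non-nullable
/// first scans it for NULLs. Metadata-only changes (rename, nullability)
/// commit as a `Project` operation; type casts rewrite the affected columns
/// and commit as a `Merge` with the new data files.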
pub(super) async fn alter_columns(
dataset: &mut Dataset,
alterations: &[ColumnAlteration],
) -> Result<()> {
let mut new_schema = dataset.schema().clone();
let mut cast_fields: Vec<(Field, Field)> = Vec::new();
let mut next_field_id = dataset.manifest.max_field_id() + 1;
for alteration in alterations {
let field_src = dataset.schema().field(&alteration.path).ok_or_else(|| {
Error::invalid_input(format!(
"Column \"{}\" does not exist in the dataset",
alteration.path
))
})?;
if let Some(nullable) = alteration.nullable
&& field_src.nullable
&& !nullable
{
validate_no_nulls_before_making_non_nullable(dataset, &alteration.path).await?;
}
let field_dest = new_schema.mut_field_by_id(field_src.id).unwrap();
if let Some(rename) = &alteration.rename {
field_dest.name.clone_from(rename);
}
if let Some(nullable) = alteration.nullable {
field_dest.nullable = nullable;
}
if let Some(data_type) = &alteration.data_type {
if !(can_cast_types(&field_src.data_type(), data_type)
&& is_upcast_downcast(&field_src.data_type(), data_type))
{
return Err(Error::invalid_input(format!(
"Cannot cast column \"{}\" from {:?} to {:?}",
alteration.path,
field_src.data_type(),
data_type
)));
}
let arrow_field = ArrowField::new(
field_dest.name.clone(),
data_type.clone(),
field_dest.nullable,
);
*field_dest = Field::try_from(&arrow_field)?;
field_dest.set_id(field_src.parent_id, &mut next_field_id);
cast_fields.push((field_src.clone(), field_dest.clone()));
}
}
new_schema.validate()?;
let transaction = if cast_fields.is_empty() {
Transaction::new(
dataset.manifest.version,
Operation::Project { schema: new_schema },
None,
)
} else {
let read_columns = cast_fields
.iter()
.map(|(old, _new)| {
let parts = dataset.schema().field_ancestry_by_id(old.id).unwrap();
let part_names = parts.iter().map(|p| p.name.clone()).collect::<Vec<_>>();
part_names.join(".")
})
.collect::<Vec<_>>();
let new_ids = cast_fields
.iter()
.map(|(_old, new)| new.id)
.collect::<Vec<_>>();
let new_col_schema = new_schema.project_by_ids(&new_ids, true);
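// The mapper reads only the columns being cast and rewrites them with
// their new types; `safe: false` makes invalid casts error out instead of
// producing NULLs.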
let mapper = move |batch: &RecordBatch| {
let mut fields = Vec::with_capacity(cast_fields.len());
let mut columns = Vec::with_capacity(batch.num_columns());
for (old, new) in &cast_fields {
let old_column = batch[&old.name].clone();
let new_column = lance_arrow::cast::cast_with_options(
&old_column,
&new.data_type(),
&CastOptions {
safe: false,
..Default::default()
},
)?;
columns.push(new_column);
fields.push(Arc::new(ArrowField::from(new)));
}
let schema = Arc::new(ArrowSchema::new(fields));
Ok(RecordBatch::try_new(schema, columns)?)
};
let mapper = Box::new(mapper);
let fragments = add_columns_impl(
&dataset.get_fragments(),
Some(read_columns),
mapper,
None,
None,
Some((new_col_schema, new_schema.clone())),
)
.await?;
let schema_field_ids = new_schema.field_ids().into_iter().collect::<Vec<_>>();
let fragments = fragments
.into_iter()
.map(|mut frag| {
frag.files.retain(|f| {
f.fields
.iter()
.any(|field| schema_field_ids.contains(field))
});
frag
})
.collect::<Vec<_>>();
Transaction::new(
dataset.manifest.version,
Operation::Merge {
schema: new_schema,
fragments,
},
None,
)
};
dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
Ok(())
}
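/// Drops the given columns (nested sub-columns are addressed with dot paths,
/// e.g. `"s.d"`) by committing a `Project` to the remaining schema. No data
/// files are rewritten; the dropped data simply becomes unreachable.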
pub(super) async fn drop_columns(dataset: &mut Dataset, columns: &[&str]) -> Result<()> {
for col in columns {
if dataset.schema().field(col).is_none() {
return Err(Error::invalid_input(format!(
"Column {} does not exist in the dataset",
col
)));
}
}
let version = dataset.manifest.data_storage_format.lance_file_version()?;
let columns_to_remove = dataset.manifest.schema.project(columns)?;
let new_schema = exclude(&dataset.manifest.schema, &columns_to_remove, &version)?;
if new_schema.fields.is_empty() {
return Err(Error::invalid_input(
"Cannot drop all columns from a dataset",
));
}
let transaction = Transaction::new(
dataset.manifest.version,
Operation::Project { schema: new_schema },
None,
);
dataset
.apply_commit(transaction, &Default::default(), &Default::default())
.await?;
Ok(())
}
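/// Returns `source` with the fields of `other` removed. When the file version
/// supports removing sub-columns, nested exclusions keep the surviving
/// siblings; otherwise the entire top-level field is dropped.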
pub fn exclude(source: &Schema, other: &Schema, version: &LanceFileVersion) -> Result<Schema> {
let other: Schema = other.try_into().map_err(|_| {
Error::schema("The other schema is not compatible with this schema".to_string())
})?;
let mut fields = vec![];
for field in source.fields.iter() {
if let Some(other_field) = other.field(&field.name) {
if version.support_remove_sub_column(field)
&& let Some(f) = field.exclude(other_field)
{
fields.push(f)
}
} else {
fields.push(field.clone());
}
}
Ok(Schema {
fields,
metadata: source.metadata.clone(),
})
}
#[cfg(test)]
mod test {
use std::collections::HashMap;
use std::sync::Mutex;
use crate::dataset::WriteParams;
use arrow_array::{
ArrayRef, Int32Array, ListArray, RecordBatchIterator, StringArray, StructArray,
};
use super::*;
use arrow_schema::Fields as ArrowFields;
use lance_core::utils::tempfile::TempStrDir;
use lance_file::version::LanceFileVersion;
use rstest::rstest;
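/// Compile-time assertion that a future is `Send`; used below to make sure
/// `add_columns` futures can cross thread boundaries.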
fn require_send<T: Send>(t: T) -> T {
t
}
#[tokio::test]
async fn test_append_columns_exprs() -> Result<()> {
let num_rows = 5;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..num_rows as i32))],
)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
data_storage_version: Some(LanceFileVersion::Legacy),
..Default::default()
}),
)
.await?;
dataset.validate().await?;
let fut = dataset.add_columns(
NewColumnTransform::SqlExpressions(vec![("id".into(), "id + 1".into())]),
None,
None,
);
let res = require_send(fut).await;
assert!(matches!(res, Err(Error::InvalidInput { .. })));
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![("value".into(), "2 * random()".into())]),
None,
None,
)
.await?;
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![("double_id".into(), "2 * id".into())]),
None,
None,
)
.await?;
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![(
"triple_id".into(),
"id + double_id".into(),
)]),
None,
None,
)
.await?;
dataset.validate().await?;
let data = dataset.scan().try_into_batch().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("value", DataType::Float64, true),
ArrowField::new("double_id", DataType::Int32, false),
ArrowField::new("triple_id", DataType::Int32, false),
]);
assert_eq!(data.schema().as_ref(), &expected_schema);
assert_eq!(data.num_rows(), num_rows);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_append_columns_udf(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
use arrow_array::Float64Array;
let num_rows = 5;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..num_rows as i32))],
)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
dataset.validate().await?;
let transforms = NewColumnTransform::BatchUDF(BatchUDF {
mapper: Box::new(|_| unimplemented!()),
output_schema: Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)])),
result_checkpoint: None,
});
let res = dataset.add_columns(transforms, None, None).await;
assert!(matches!(res, Err(Error::InvalidInput { .. })));
let output_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"value",
DataType::Float64,
true,
)]));
let output_schema_ref = output_schema.clone();
let mapper = move |batch: &RecordBatch| {
Ok(RecordBatch::try_new(
output_schema_ref.clone(),
vec![Arc::new(Float64Array::from_iter_values(
(0..batch.num_rows()).map(|i| i as f64),
))],
)?)
};
let transforms = NewColumnTransform::BatchUDF(BatchUDF {
mapper: Box::new(mapper),
output_schema,
result_checkpoint: None,
});
dataset.add_columns(transforms, None, None).await?;
let output_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"double_id",
DataType::Int32,
false,
)]));
let output_schema_ref = output_schema.clone();
let mapper = move |batch: &RecordBatch| {
let id = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
Ok(RecordBatch::try_new(
output_schema_ref.clone(),
vec![Arc::new(Int32Array::from_iter_values(
id.values().iter().map(|i| i * 2),
))],
)?)
};
let transforms = NewColumnTransform::BatchUDF(BatchUDF {
mapper: Box::new(mapper),
output_schema,
result_checkpoint: None,
});
dataset.add_columns(transforms, None, None).await?;
dataset.validate().await?;
let data = dataset.scan().try_into_batch().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("value", DataType::Float64, true),
ArrowField::new("double_id", DataType::Int32, false),
]);
assert_eq!(data.schema().as_ref(), &expected_schema);
assert_eq!(data.num_rows(), num_rows);
Ok(())
}
#[tokio::test]
async fn test_append_columns_udf_cache() -> Result<()> {
let num_rows = 100;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
max_rows_per_file: 50,
max_rows_per_group: 25,
data_storage_version: Some(LanceFileVersion::Legacy),
..Default::default()
}),
)
.await?;
dataset.validate().await?;
#[derive(Default)]
struct RequestCounter {
pub get_batch_requests: Mutex<Vec<BatchInfo>>,
pub insert_batch_requests: Mutex<Vec<BatchInfo>>,
pub get_fragment_requests: Mutex<Vec<u32>>,
pub insert_fragment_requests: Mutex<Vec<u32>>,
}
impl UDFCheckpointStore for RequestCounter {
fn get_batch(&self, info: &BatchInfo) -> Result<Option<RecordBatch>> {
self.get_batch_requests.lock().unwrap().push(info.clone());
if info.fragment_id == 1 && info.batch_index == 0 {
Ok(Some(RecordBatch::try_new(
Arc::new(ArrowSchema::new(vec![ArrowField::new(
"double_id",
DataType::Int32,
false,
)])),
vec![Arc::new(Int32Array::from_iter_values(50..75))],
)?))
} else {
Ok(None)
}
}
fn insert_batch(&self, info: BatchInfo, _value: RecordBatch) -> Result<()> {
self.insert_batch_requests.lock().unwrap().push(info);
Ok(())
}
fn get_fragment(&self, fragment_id: u32) -> Result<Option<Fragment>> {
self.get_fragment_requests.lock().unwrap().push(fragment_id);
if fragment_id == 0 {
Ok(Some(Fragment {
files: vec![],
id: 0,
deletion_file: None,
row_id_meta: None,
physical_rows: Some(50),
last_updated_at_version_meta: None,
created_at_version_meta: None,
}))
} else {
Ok(None)
}
}
fn insert_fragment(&self, fragment: Fragment) -> Result<()> {
self.insert_fragment_requests
.lock()
.unwrap()
.push(fragment.id as u32);
Ok(())
}
}
let request_counter = Arc::new(RequestCounter::default());
let output_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"double_id",
DataType::Int32,
false,
)]));
let output_schema_ref = output_schema.clone();
let mapper = move |batch: &RecordBatch| {
let id = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
Ok(RecordBatch::try_new(
output_schema_ref.clone(),
vec![Arc::new(Int32Array::from_iter_values(
id.values().iter().map(|i| i * 2),
))],
)?)
};
let transforms = NewColumnTransform::BatchUDF(BatchUDF {
mapper: Box::new(mapper),
output_schema,
result_checkpoint: Some(request_counter.clone()),
});
dataset.add_columns(transforms, None, None).await?;
assert_eq!(
request_counter
.get_fragment_requests
.lock()
.unwrap()
.as_slice(),
&[0, 1]
);
assert_eq!(
request_counter
.insert_fragment_requests
.lock()
.unwrap()
.as_slice(),
&[1]
);
assert_eq!(
request_counter
.get_batch_requests
.lock()
.unwrap()
.as_slice(),
&[
BatchInfo {
fragment_id: 1,
batch_index: 0,
},
BatchInfo {
fragment_id: 1,
batch_index: 1,
},
]
);
assert_eq!(
request_counter
.insert_batch_requests
.lock()
.unwrap()
.as_slice(),
&[BatchInfo {
fragment_id: 1,
batch_index: 1,
},]
);
Ok(())
}
#[tokio::test]
async fn test_add_column_all_nulls() -> Result<()> {
let num_rows = 100;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
max_rows_per_file: 50,
max_rows_per_group: 25,
data_storage_version: Some(LanceFileVersion::Stable),
..Default::default()
}),
)
.await?;
dataset.validate().await?;
dataset
.add_columns(
NewColumnTransform::AllNulls(Arc::new(ArrowSchema::new(vec![ArrowField::new(
"nulls",
DataType::Int32,
true,
)]))),
None,
None,
)
.await?;
let data = dataset.scan().try_into_batch().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("nulls", DataType::Int32, true),
]);
assert_eq!(data.schema().as_ref(), &expected_schema);
assert_eq!(data.num_rows(), num_rows as usize);
let err =
dataset
.add_columns(
NewColumnTransform::AllNulls(Arc::new(ArrowSchema::new(vec![
ArrowField::new("non_nulls", DataType::Int32, false),
]))),
None,
None,
)
.await
.unwrap_err();
assert!(
err.to_string()
.contains("All-null columns must be nullable.")
);
let data = dataset.scan().try_into_batch().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("nulls", DataType::Int32, true),
]);
assert_eq!(data.schema().as_ref(), &expected_schema);
assert_eq!(data.num_rows(), num_rows as usize);
Ok(())
}
#[tokio::test]
async fn test_add_column_all_nulls_legacy() -> Result<()> {
let num_rows = 100;
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
max_rows_per_file: 50,
max_rows_per_group: 25,
data_storage_version: Some(LanceFileVersion::Legacy),
..Default::default()
}),
)
.await?;
dataset.validate().await?;
let err =
dataset
.add_columns(
NewColumnTransform::AllNulls(Arc::new(ArrowSchema::new(vec![
ArrowField::new("nulls", DataType::Int32, true),
]))),
None,
None,
)
.await
.unwrap_err();
assert!(
err.to_string()
.contains("Cannot add all-null columns to legacy dataset version")
);
Ok(())
}
async fn prepare_dataset(version: LanceFileVersion) -> Result<Dataset> {
let person_struct_type = DataType::Struct(ArrowFields::from(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("age", DataType::Int32, false),
ArrowField::new("city", DataType::Utf8, false),
]));
let list_of_struct_type = DataType::List(Arc::new(ArrowField::new(
"item",
person_struct_type.clone(),
false,
)));
let schema = Arc::new(ArrowSchema::new_with_metadata(
vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new("people", list_of_struct_type.clone(), false),
],
HashMap::<String, String>::new(),
));
let all_names = StringArray::from(vec!["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]);
let all_ages = Int32Array::from(vec![25, 30, 35, 28, 32, 40]);
let all_cities = StringArray::from(vec![
"Beijing",
"Shanghai",
"Guangzhou",
"Shenzhen",
"Hangzhou",
"Chengdu",
]);
let all_struct = StructArray::new(
ArrowFields::from(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("age", DataType::Int32, false),
ArrowField::new("city", DataType::Utf8, false),
]),
vec![
Arc::new(all_names) as ArrayRef,
Arc::new(all_ages) as ArrayRef,
Arc::new(all_cities) as ArrayRef,
],
None,
);
let all_people = ListArray::new(
Arc::new(ArrowField::new("item", person_struct_type, false)),
arrow_buffer::OffsetBuffer::new(arrow_buffer::ScalarBuffer::from(vec![
0i32, 2i32, 5i32, 6i32,
])),
Arc::new(all_struct),
None,
);
let ids = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(ids) as ArrayRef, Arc::new(all_people) as ArrayRef],
)?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let dataset = Dataset::write(
reader,
"memory://test",
Some(WriteParams {
data_storage_version: Some(version),
..Default::default()
}),
)
.await?;
assert_eq!(dataset.schema().fields.len(), 2);
assert_eq!(dataset.schema().fields[0].name, "id");
assert_eq!(dataset.schema().fields[1].name, "people");
Ok(dataset)
}
#[rstest]
#[tokio::test]
async fn test_drop_list_struct_sub_columns_legacy(
#[values(
LanceFileVersion::Legacy,
LanceFileVersion::V2_0,
LanceFileVersion::V2_1
)]
version: LanceFileVersion,
) -> Result<()> {
let mut dataset = prepare_dataset(version).await?;
dataset.drop_columns(&["people.item.city"]).await?;
dataset.validate().await?;
assert_eq!(dataset.schema().fields.len(), 1);
assert_eq!(dataset.schema().fields[0].name, "id");
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_drop_list_struct_sub_columns(
#[values(LanceFileVersion::V2_2)] version: LanceFileVersion,
) -> Result<()> {
let mut dataset = prepare_dataset(version).await?;
dataset.drop_columns(&["people.item.city"]).await?;
dataset.validate().await?;
let expected_schema = ArrowSchema::new_with_metadata(
vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new(
"people",
DataType::List(Arc::new(ArrowField::new(
"item",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("name", DataType::Utf8, false),
ArrowField::new("age", DataType::Int32, false),
])),
false,
))),
false,
),
],
HashMap::<String, String>::new(),
);
assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);
let batch = dataset.scan().try_into_batch().await?;
assert_eq!(batch.num_rows(), 3);
assert_eq!(batch.num_columns(), 2);
let list_array = batch
.column(1)
.as_any()
.downcast_ref::<ListArray>()
.unwrap();
let list_value = list_array.value(0);
let struct_array = list_value.as_any().downcast_ref::<StructArray>().unwrap();
assert!(struct_array.column_by_name("city").is_none());
Ok(())
}
#[test]
fn test_exclude_fields() {
let arrow_schema = ArrowSchema::new(vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("f1", DataType::Utf8, true),
ArrowField::new("f2", DataType::Boolean, false),
ArrowField::new("f3", DataType::Float32, false),
])),
true,
),
ArrowField::new("c", DataType::Float64, false),
]);
let schema = Schema::try_from(&arrow_schema).unwrap();
let projection = schema.project(&["a", "b.f2", "b.f3"]).unwrap();
let excluded = exclude(&schema, &projection, &LanceFileVersion::V2_2).unwrap();
let expected_arrow_schema = ArrowSchema::new(vec![
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"f1",
DataType::Utf8,
true,
)])),
true,
),
ArrowField::new("c", DataType::Float64, false),
]);
assert_eq!(ArrowSchema::from(&excluded), expected_arrow_schema);
}
#[rstest]
#[tokio::test]
async fn test_rename_columns(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
use std::collections::HashMap;
use arrow_array::{ArrayRef, StructArray};
let metadata: HashMap<String, String> = [("k1".into(), "v1".into())].into();
let schema = Arc::new(ArrowSchema::new_with_metadata(
vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"c",
DataType::Int32,
true,
)])),
true,
),
],
metadata.clone(),
));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StructArray::from(vec![(
Arc::new(ArrowField::new("c", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
)])),
],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let mut dataset = Dataset::write(
batches,
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let original_fragments = dataset.fragments().to_vec();
dataset
.alter_columns(&[ColumnAlteration::new("a".into())
.rename("x".into())
.set_nullable(true)])
.await?;
dataset.validate().await?;
assert_eq!(dataset.manifest.version, 2);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);
let expected_schema = ArrowSchema::new_with_metadata(
vec![
ArrowField::new("x", DataType::Int32, true),
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"c",
DataType::Int32,
true,
)])),
true,
),
],
metadata.clone(),
);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);
let err = dataset
.alter_columns(&[ColumnAlteration::new("b".into()).rename("x".into())])
.await
.unwrap_err();
assert!(err.to_string().contains("Duplicate field name \"x\""));
dataset
.alter_columns(&[ColumnAlteration::new("b.c".into()).rename("d".into())])
.await?;
dataset.validate().await?;
assert_eq!(dataset.manifest.version, 3);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);
let expected_schema = ArrowSchema::new_with_metadata(
vec![
ArrowField::new("x", DataType::Int32, true),
ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"d",
DataType::Int32,
true,
)])),
true,
),
],
metadata.clone(),
);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_set_not_null_succeeds(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"a",
DataType::Int32,
true,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values([1, 2, 3]))],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let original_fragments = dataset.fragments().to_vec();
dataset
.alter_columns(&[ColumnAlteration::new("a".into()).set_nullable(false)])
.await?;
dataset.validate().await?;
assert_eq!(dataset.manifest.version, 2);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);
assert_eq!(
&ArrowSchema::from(dataset.schema()),
&ArrowSchema::new(vec![ArrowField::new("a", DataType::Int32, false)])
);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_set_not_null_succeeds_nested(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
use arrow_array::{ArrayRef, StructArray};
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"c",
DataType::Int32,
true,
)])),
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(StructArray::from(vec![(
Arc::new(ArrowField::new("c", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
)]))],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let original_fragments = dataset.fragments().to_vec();
dataset
.alter_columns(&[ColumnAlteration::new("b.c".into()).set_nullable(false)])
.await?;
dataset.validate().await?;
assert_eq!(dataset.fragments().as_ref(), &original_fragments);
assert_eq!(
&ArrowSchema::from(dataset.schema()),
&ArrowSchema::new(vec![ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"c",
DataType::Int32,
false
)])),
false
)])
);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_set_not_null_fails_with_nulls(
#[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
) -> Result<()> {
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"a",
DataType::Int32,
true,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]))],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let err = dataset
.alter_columns(&[ColumnAlteration::new("a".into()).set_nullable(false)])
.await
.unwrap_err();
assert!(err.to_string().contains("contains NULL values"));
assert_eq!(
&ArrowSchema::from(dataset.schema()),
&ArrowSchema::new(vec![ArrowField::new("a", DataType::Int32, true)])
);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_set_not_null_fails_with_nulls_nested(
#[values(LanceFileVersion::Stable)] data_storage_version: LanceFileVersion,
) -> Result<()> {
use arrow_array::{ArrayRef, StructArray};
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"c",
DataType::Int32,
true,
)])),
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(StructArray::from(vec![(
Arc::new(ArrowField::new("c", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
)]))],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(batch)], schema.clone()),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let err = dataset
.alter_columns(&[ColumnAlteration::new("b.c".into()).set_nullable(false)])
.await
.unwrap_err();
assert!(err.to_string().contains("contains NULL values"));
assert_eq!(
&ArrowSchema::from(dataset.schema()),
&ArrowSchema::new(vec![ArrowField::new(
"b",
DataType::Struct(ArrowFields::from(vec![ArrowField::new(
"c",
DataType::Int32,
true
)])),
false
)])
);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_cast_column(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
use arrow::datatypes::{Int32Type, Int64Type};
use arrow_array::{Float16Array, Float32Array, Int64Array, ListArray};
use half::f16;
use lance_arrow::FixedSizeListArrayExt;
use lance_index::{DatasetIndexExt, IndexType, scalar::ScalarIndexParams};
use lance_linalg::distance::MetricType;
use lance_testing::datagen::generate_random_array;
use crate::index::vector::VectorIndexParams;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, false),
ArrowField::new("f", DataType::Float32, false),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
128,
),
false,
),
ArrowField::new("l", DataType::new_list(DataType::Int32, true), true),
]));
let nrows = 512;
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..nrows)),
Arc::new(Float32Array::from_iter_values((0..nrows).map(|i| i as f32))),
Arc::new(
<arrow_array::FixedSizeListArray as FixedSizeListArrayExt>::try_new_from_values(
generate_random_array(128 * nrows as usize),
128,
)
.unwrap(),
),
Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(
(0..nrows).map(|i| Some(vec![Some(i), Some(i + 1)])),
)),
],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()),
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let params = VectorIndexParams::ivf_pq(10, 8, 2, MetricType::L2, 50);
dataset
.create_index(&["vec"], IndexType::Vector, None, ¶ms, false)
.await?;
dataset
.create_index(
&["i"],
IndexType::Scalar,
None,
&ScalarIndexParams::default(),
false,
)
.await?;
dataset.validate().await?;
let indices = dataset.load_indices().await?;
assert_eq!(indices.len(), 2);
dataset
.alter_columns(&[ColumnAlteration::new("f".into())
.cast_to(DataType::Float16)
.set_nullable(true)])
.await?;
dataset.validate().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, false),
ArrowField::new("f", DataType::Float16, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
128,
),
false,
),
ArrowField::new("l", DataType::new_list(DataType::Int32, true), true),
]);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);
dataset.fragments().iter().for_each(|f| {
assert_eq!(f.files.len(), 2);
});
dataset
.alter_columns(&[ColumnAlteration::new("i".into()).cast_to(DataType::Int64)])
.await?;
dataset.validate().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int64, false),
ArrowField::new("f", DataType::Float16, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
128,
),
false,
),
ArrowField::new("l", DataType::new_list(DataType::Int32, true), true),
]);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);
let indices = dataset.load_indices().await?;
assert_eq!(indices.len(), 1);
dataset.fragments().iter().for_each(|f| {
assert_eq!(f.files.len(), 3);
});
dataset
.alter_columns(&[
ColumnAlteration::new("vec".into()).cast_to(DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float16, true)),
128,
)),
])
.await?;
dataset.validate().await?;
dataset
.alter_columns(&[ColumnAlteration::new("l".into())
.cast_to(DataType::new_list(DataType::Int64, true))])
.await?;
dataset.validate().await?;
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int64, false),
ArrowField::new("f", DataType::Float16, true),
ArrowField::new(
"vec",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float16, true)),
128,
),
false,
),
ArrowField::new("l", DataType::new_list(DataType::Int64, true), true),
]);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);
let indices = dataset.load_indices().await?;
assert_eq!(indices.len(), 0);
dataset.fragments().iter().for_each(|f| {
assert_eq!(f.files.len(), 4);
});
let expected_data = RecordBatch::try_new(
Arc::new(expected_schema),
vec![
Arc::new(Int64Array::from_iter_values(0..nrows as i64)),
Arc::new(Float16Array::from_iter_values(
(0..nrows).map(|i| f16::from_f32(i as f32)),
)),
lance_arrow::cast::cast_with_options(
batch["vec"].as_ref(),
&DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float16, true)),
128,
),
&Default::default(),
)?,
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(
(0..nrows as i64).map(|i| Some(vec![Some(i), Some(i + 1)])),
)),
],
)?;
let actual_data = dataset.scan().try_into_batch().await?;
assert_eq!(actual_data, expected_data);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_drop_columns(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
use std::collections::HashMap;
use arrow_array::{ArrayRef, Float32Array, StructArray};
let metadata: HashMap<String, String> = [("k1".into(), "v1".into())].into();
let schema = Arc::new(ArrowSchema::new_with_metadata(
vec![
ArrowField::new("i", DataType::Int32, false),
ArrowField::new(
"s",
DataType::Struct(ArrowFields::from(vec![
ArrowField::new("d", DataType::Int32, true),
ArrowField::new("l", DataType::Int32, true),
])),
true,
),
ArrowField::new("x", DataType::Float32, false),
],
metadata.clone(),
));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StructArray::from(vec![
(
Arc::new(ArrowField::new("d", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
),
(
Arc::new(ArrowField::new("l", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2])),
),
])),
Arc::new(Float32Array::from(vec![1.0, 2.0])),
],
)?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let mut dataset = Dataset::write(
batches,
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
let lance_schema = dataset.schema().clone();
let original_fragments = dataset.fragments().to_vec();
dataset.drop_columns(&["x"]).await?;
dataset.validate().await?;
let expected_schema = lance_schema.project(&["i", "s"])?;
assert_eq!(dataset.schema(), &expected_schema);
assert_eq!(dataset.version().version, 2);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);
dataset.drop_columns(&["s.d"]).await?;
dataset.validate().await?;
let expected_schema = expected_schema.project(&["i", "s.l"])?;
assert_eq!(dataset.schema(), &expected_schema);
let expected_data = RecordBatch::try_new(
Arc::new(ArrowSchema::from(&expected_schema)),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StructArray::from(vec![(
Arc::new(ArrowField::new("l", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
)])),
],
)?;
let actual_data = dataset.scan().try_into_batch().await?;
assert_eq!(actual_data, expected_data);
assert_eq!(dataset.version().version, 3);
assert_eq!(dataset.fragments().as_ref(), &original_fragments);
Ok(())
}
#[rstest]
#[tokio::test]
async fn test_drop_add_columns(
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)]
data_storage_version: LanceFileVersion,
) -> Result<()> {
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"i",
DataType::Int32,
false,
)]));
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))])?;
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let mut dataset = Dataset::write(
batches,
test_uri,
Some(WriteParams {
data_storage_version: Some(data_storage_version),
..Default::default()
}),
)
.await?;
assert_eq!(dataset.manifest.max_field_id(), 0);
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![("x".into(), "i + 1".into())]),
Some(vec!["i".into()]),
None,
)
.await?;
assert_eq!(dataset.manifest.max_field_id(), 1);
dataset.drop_columns(&["x"]).await?;
assert_eq!(dataset.manifest.max_field_id(), 0);
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![("y".into(), "2 * i".into())]),
Some(vec!["i".into()]),
None,
)
.await?;
assert_eq!(dataset.manifest.max_field_id(), 1);
let data = dataset.scan().try_into_batch().await?;
let expected_data = RecordBatch::try_new(
Arc::new(schema.try_with_column(ArrowField::new("y", DataType::Int32, false))?),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(Int32Array::from(vec![2, 4])),
],
)?;
assert_eq!(data, expected_data);
dataset.drop_columns(&["y"]).await?;
assert_eq!(dataset.manifest.max_field_id(), 0);
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![
("a".into(), "i + 3".into()),
("b".into(), "i + 7".into()),
]),
Some(vec!["i".into()]),
None,
)
.await?;
assert_eq!(dataset.manifest.max_field_id(), 2);
dataset.drop_columns(&["b"]).await?;
assert_eq!(dataset.manifest.max_field_id(), 2);
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![("c".into(), "i + 11".into())]),
Some(vec!["i".into()]),
None,
)
.await?;
assert_eq!(dataset.manifest.max_field_id(), 3);
let data = dataset.scan().try_into_batch().await?;
let expected_schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("i", DataType::Int32, false),
ArrowField::new("a", DataType::Int32, false),
ArrowField::new("c", DataType::Int32, false),
]));
let expected_data = RecordBatch::try_new(
expected_schema,
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(Int32Array::from(vec![4, 5])),
Arc::new(Int32Array::from(vec![12, 13])),
],
)?;
assert_eq!(data, expected_data);
Ok(())
}
#[tokio::test]
async fn test_new_column_sql_to_all_nulls_transform_optimizer() {
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"a",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter(0..100))],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
max_rows_per_file: 50,
max_rows_per_group: 25,
data_storage_version: Some(LanceFileVersion::Stable),
..Default::default()
}),
)
.await
.unwrap();
dataset.validate().await.unwrap();
let manifest_before = dataset.manifest.clone();
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![(
"b".to_string(),
"CAST(NULL AS int)".to_string(),
)]),
None,
None,
)
.await
.unwrap();
let manifest_after = dataset.manifest.clone();
assert_eq!(&manifest_before.fragments, &manifest_after.fragments);
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new("b", DataType::Int32, true),
]);
assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);
}
#[tokio::test]
async fn test_new_column_sql_to_all_nulls_transform_optimizer_legacy() {
let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"a",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter(0..100))],
)
.unwrap();
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
let test_dir = TempStrDir::default();
let test_uri = &test_dir;
let mut dataset = Dataset::write(
reader,
test_uri,
Some(WriteParams {
max_rows_per_file: 50,
max_rows_per_group: 25,
data_storage_version: Some(LanceFileVersion::Legacy),
..Default::default()
}),
)
.await
.unwrap();
dataset.validate().await.unwrap();
dataset
.add_columns(
NewColumnTransform::SqlExpressions(vec![(
"b".to_string(),
"CAST(NULL AS int)".to_string(),
)]),
None,
None,
)
.await
.unwrap();
let expected_schema = ArrowSchema::new(vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new("b", DataType::Int32, true),
]);
assert_eq!(ArrowSchema::from(dataset.schema()), expected_schema);
}
#[test]
fn test_check_field_conflict() {
let field1 = ArrowField::new(
"test",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
);
let field2 = ArrowField::new(
"test",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
);
let field2 = ArrowField::new(
"test",
DataType::Struct(vec![ArrowField::new("b", DataType::Int32, false)].into()),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_ok());
let inner_struct1 = ArrowField::new(
"inner",
DataType::Struct(vec![ArrowField::new("x", DataType::Int32, false)].into()),
false,
);
let inner_struct2 = ArrowField::new(
"inner",
DataType::Struct(vec![ArrowField::new("x", DataType::Int32, false)].into()),
false,
);
let field1 = ArrowField::new("test", DataType::Struct(vec![inner_struct1].into()), false);
let field2 = ArrowField::new("test", DataType::Struct(vec![inner_struct2].into()), false);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new("test1", DataType::Int32, false);
let field2 = ArrowField::new("test2", DataType::Int32, false);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_ok());
let field1 = ArrowField::new("test", DataType::Int32, false);
let field2 = ArrowField::new("test", DataType::Int32, false);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new("test", DataType::Int32, false);
let field2 = ArrowField::new("test", DataType::Float64, false);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::Struct(
vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new("b", DataType::Utf8, false),
]
.into(),
),
false,
);
let field2 = ArrowField::new(
"test",
DataType::Struct(
vec![
ArrowField::new("a", DataType::Int32, false),
ArrowField::new("c", DataType::Utf8, false),
]
.into(),
),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))),
false,
);
let field2 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
let field2 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
let field2 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("b", DataType::Int32, false)].into()),
false,
))),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_ok());
let field1 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
let field2 = ArrowField::new(
"test",
DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::FixedSizeList(
Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
)),
2,
),
false,
);
let field2 = ArrowField::new(
"test",
DataType::FixedSizeList(
Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
)),
2,
),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::FixedSizeList(
Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
)),
2,
),
false,
);
let field2 = ArrowField::new(
"test",
DataType::FixedSizeList(
Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("b", DataType::Int32, false)].into()),
false,
)),
2,
),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_ok());
let field1 = ArrowField::new(
"test",
DataType::LargeList(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
let field2 = ArrowField::new(
"test",
DataType::LargeList(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_err());
let field1 = ArrowField::new(
"test",
DataType::LargeList(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("a", DataType::Int32, false)].into()),
false,
))),
false,
);
let field2 = ArrowField::new(
"test",
DataType::LargeList(Arc::new(ArrowField::new(
"item",
DataType::Struct(vec![ArrowField::new("b", DataType::Int32, false)].into()),
false,
))),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_ok());
let mut packed_meta = HashMap::new();
packed_meta.insert(PACKED_STRUCT_META_KEY.to_string(), "true".to_string());
let packed_field = ArrowField::new(
"packed",
DataType::Struct(vec![ArrowField::new("foo", DataType::Int32, false)].into()),
false,
)
.with_metadata(packed_meta.clone());
let field1 = ArrowField::new("test", DataType::Struct(vec![packed_field].into()), false);
let field2 = ArrowField::new(
"test",
DataType::Struct(vec![ArrowField::new("b", DataType::Int32, false)].into()),
false,
);
assert!(check_field_conflict(&field1, &field2, &LanceFileVersion::V2_2).is_ok());
let new_packed_field = ArrowField::new(
"new_packed",
DataType::Struct(vec![ArrowField::new("foo", DataType::Int32, false)].into()),
false,
)
.with_metadata(packed_meta.clone());
let field3 = ArrowField::new(
"test",
DataType::Struct(vec![new_packed_field].into()),
false,
);
assert!(check_field_conflict(&field1, &field3, &LanceFileVersion::V2_2).is_ok());
let conflict_field = ArrowField::new(
"packed",
DataType::Struct(vec![ArrowField::new("new_col", DataType::Int32, false)].into()),
false,
)
.with_metadata(packed_meta);
let field4 = ArrowField::new("test", DataType::Struct(vec![conflict_field].into()), false);
assert!(check_field_conflict(&field1, &field4, &LanceFileVersion::V2_2).is_err());
}
}