use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{RecordBatch, RecordBatchIterator};
use datafusion::execution::SendableRecordBatchStream;
use humantime::format_duration;
use lance_core::datatypes::{NullabilityComparison, Schema, SchemaCompareOptions};
use lance_core::utils::tracing::{DATASET_WRITING_EVENT, TRACE_DATASET_EVENTS};
use lance_core::{ROW_ADDR, ROW_ID, ROW_OFFSET};
use lance_datafusion::utils::StreamingWriteSource;
use lance_file::version::LanceFileVersion;
use lance_io::object_store::ObjectStore;
use lance_table::feature_flags::can_write_dataset;
use lance_table::format::Fragment;
use lance_table::io::commit::CommitHandler;
use object_store::path::Path;
use crate::Dataset;
use crate::dataset::ReadParams;
use crate::dataset::builder::DatasetBuilder;
use crate::dataset::transaction::{Operation, Transaction, TransactionBuilder};
use crate::dataset::write::{validate_and_resolve_target_bases, write_fragments_internal};
use crate::{Error, Result};
use tracing::info;
use super::WriteDestination;
use super::WriteMode;
use super::WriteParams;
use super::commit::CommitBuilder;
use super::resolve_commit_handler;
/// Builder for inserting data into a new or existing Lance dataset.
///
/// Wraps a write destination (an open [`Dataset`] or a URI) together with
/// optional [`WriteParams`], and offers both committing (`execute*`) and
/// transaction-only (`execute_uncommitted*`) execution paths.
#[derive(Debug, Clone)]
pub struct InsertBuilder<'a> {
    // Where the data is written: an existing dataset handle or a URI.
    dest: WriteDestination<'a>,
    // Optional write parameters; defaults are used when `None`.
    params: Option<&'a WriteParams>,
}
impl<'a> InsertBuilder<'a> {
    /// Creates a builder targeting `dest` with default parameters.
    pub fn new(dest: impl Into<WriteDestination<'a>>) -> Self {
        Self {
            dest: dest.into(),
            params: None,
        }
    }

    /// Sets the write parameters to use instead of the defaults.
    pub fn with_params(mut self, params: &'a WriteParams) -> Self {
        self.params = Some(params);
        self
    }

    /// Writes `data` and commits the resulting transaction, returning the
    /// updated dataset.
    ///
    /// # Errors
    ///
    /// Fails if `data` is empty, if the batches do not all share the same
    /// schema, if validation fails, or if the write/commit itself errors.
    pub async fn execute(&self, data: Vec<RecordBatch>) -> Result<Dataset> {
        let (transaction, context) = self.write_uncommitted_impl(data).await?;
        Self::do_commit(&context, transaction).await
    }

    /// Streaming variant of [`Self::execute`]: consumes `source` and commits.
    pub async fn execute_stream(&self, source: impl StreamingWriteSource) -> Result<Dataset> {
        let (stream, schema) = source.into_stream_and_schema().await?;
        self.execute_stream_impl(stream, schema).await
    }

    // Shared tail of the committing stream path: write, then commit.
    async fn execute_stream_impl(
        &self,
        stream: SendableRecordBatchStream,
        schema: Schema,
    ) -> Result<Dataset> {
        let (transaction, context) = self.write_uncommitted_stream_impl(stream, schema).await?;
        Self::do_commit(&context, transaction).await
    }

    /// Writes `data` but does NOT commit; returns the transaction so the
    /// caller can commit it later (e.g. together with other operations).
    pub async fn execute_uncommitted(&self, data: Vec<RecordBatch>) -> Result<Transaction> {
        self.write_uncommitted_impl(data).await.map(|(t, _)| t)
    }

    /// Commits `transaction` using the settings captured in `context`.
    async fn do_commit(context: &WriteContext<'_>, transaction: Transaction) -> Result<Dataset> {
        let mut commit_builder = CommitBuilder::new(context.dest.clone())
            .use_stable_row_ids(context.params.enable_stable_row_ids)
            .with_storage_format(context.storage_version)
            .enable_v2_manifest_paths(context.params.enable_v2_manifest_paths)
            .with_commit_handler(context.commit_handler.clone())
            .with_object_store(context.object_store.clone())
            .with_skip_auto_cleanup(context.params.skip_auto_cleanup);
        // Store params and session are optional; only forward them when set.
        if let Some(params) = context.params.store_params.as_ref() {
            commit_builder = commit_builder.with_store_params(params.clone());
        }
        if let Some(session) = context.params.session.as_ref() {
            commit_builder = commit_builder.with_session(session.clone());
        }
        commit_builder.execute(transaction).await
    }

    // Validates the in-memory batches, then funnels them through the
    // streaming write path.
    async fn write_uncommitted_impl(
        &self,
        data: Vec<RecordBatch>,
    ) -> Result<(Transaction, WriteContext<'_>)> {
        if data.is_empty() {
            return Err(Error::invalid_input_source("No data to write".into()));
        }
        // Every batch must match the first batch's schema exactly.
        let schema = data[0].schema();
        for batch in data.iter().skip(1) {
            if batch.schema() != schema {
                return Err(Error::invalid_input_source(
                    "All record batches must have the same schema".into(),
                ));
            }
        }
        let reader = RecordBatchIterator::new(data.into_iter().map(Ok), schema);
        let (stream, schema) = reader.into_stream_and_schema().await?;
        self.write_uncommitted_stream_impl(stream, schema).await
    }

    /// Streaming variant of [`Self::execute_uncommitted`]: writes `source`
    /// and returns the uncommitted transaction.
    pub async fn execute_uncommitted_stream(
        &self,
        source: impl StreamingWriteSource,
    ) -> Result<Transaction> {
        let (stream, schema) = source.into_stream_and_schema().await?;
        let (transaction, _) = self.write_uncommitted_stream_impl(stream, schema).await?;
        Ok(transaction)
    }

    // Core write path: resolve the context, validate, write the fragments,
    // then build the (uncommitted) transaction describing the change.
    async fn write_uncommitted_stream_impl(
        &self,
        stream: SendableRecordBatchStream,
        schema: Schema,
    ) -> Result<(Transaction, WriteContext<'_>)> {
        let mut context = self.resolve_context().await?;
        info!(
            target: TRACE_DATASET_EVENTS,
            event=DATASET_WRITING_EVENT,
            uri=context.dest.uri(),
            mode=?context.params.mode
        );
        // May mutate context.params (e.g. downgrade Append/Overwrite to
        // Create when the destination does not exist yet).
        self.validate_write(&mut context, &schema)?;
        let existing_base_paths = context.dest.dataset().map(|ds| &ds.manifest.base_paths);
        let target_base_info =
            validate_and_resolve_target_bases(&mut context.params, existing_base_paths).await?;
        let (written_fragments, written_schema) = write_fragments_internal(
            context.dest.dataset(),
            context.object_store.clone(),
            &context.base_path,
            schema.clone(),
            stream,
            context.params.clone(),
            target_base_info,
        )
        .await?;
        let transaction = Self::build_transaction(written_schema, written_fragments, &context)?;
        Ok((transaction, context))
    }

    // Maps the write mode to the matching operation and wraps it in a
    // transaction against the destination's current manifest version
    // (0 when there is no dataset yet).
    fn build_transaction(
        schema: Schema,
        fragments: Vec<Fragment>,
        context: &WriteContext<'_>,
    ) -> Result<Transaction> {
        let operation = match context.params.mode {
            WriteMode::Create => {
                // Creating a dataset is modeled as an Overwrite; auto-cleanup
                // settings (if provided) are recorded as config upserts.
                let mut upsert_values = HashMap::new();
                if let Some(auto_cleanup_params) = context.params.auto_cleanup.as_ref() {
                    upsert_values.insert(
                        String::from("lance.auto_cleanup.interval"),
                        auto_cleanup_params.interval.to_string(),
                    );
                    // `older_than` is a chrono-style duration; converting to
                    // std::time::Duration fails for negative/out-of-range
                    // values, which is reported as invalid input.
                    let duration = auto_cleanup_params
                        .older_than
                        .to_std()
                        .map_err(|e| Error::invalid_input_source(e.into()))?;
                    upsert_values.insert(
                        String::from("lance.auto_cleanup.older_than"),
                        format_duration(duration).to_string(),
                    );
                }
                let config_upsert_values = if upsert_values.is_empty() {
                    None
                } else {
                    Some(upsert_values)
                };
                Operation::Overwrite {
                    schema,
                    fragments,
                    config_upsert_values,
                    initial_bases: context.params.initial_bases.clone(),
                }
            }
            WriteMode::Overwrite => Operation::Overwrite {
                schema,
                fragments,
                config_upsert_values: None,
                initial_bases: context.params.initial_bases.clone(),
            },
            WriteMode::Append => Operation::Append { fragments },
        };
        let transaction = TransactionBuilder::new(
            context
                .dest
                .dataset()
                .map(|ds| ds.manifest.version)
                .unwrap_or(0),
            operation,
        )
        .transaction_properties(context.params.transaction_properties.clone())
        .build();
        Ok(transaction)
    }

    // Checks that the requested write is compatible with the destination and
    // the incoming data schema, adjusting `context.params` where needed.
    fn validate_write(&self, context: &mut WriteContext, data_schema: &Schema) -> Result<()> {
        // Reconcile the requested mode with what exists at the destination.
        match (&context.params.mode, &context.dest) {
            (WriteMode::Create, WriteDestination::Dataset(ds)) => {
                return Err(Error::dataset_already_exists(ds.uri.clone()));
            }
            (WriteMode::Append | WriteMode::Overwrite, WriteDestination::Uri(uri)) => {
                log::warn!("No existing dataset at {uri}, it will be created");
                context.params.mode = WriteMode::Create;
            }
            _ => {}
        }
        if matches!(context.params.mode, WriteMode::Append)
            && let WriteDestination::Dataset(dataset) = &context.dest
        {
            // When appending, the dataset's existing stable-row-id setting
            // always wins over the user-provided one.
            if context.params.enable_stable_row_ids != dataset.manifest.uses_stable_row_ids() {
                log::info!(
                    "Ignoring user provided stable row ids setting of {}, dataset already has it set to {}",
                    context.params.enable_stable_row_ids,
                    dataset.manifest.uses_stable_row_ids()
                );
                context.params.enable_stable_row_ids = dataset.manifest.uses_stable_row_ids();
            }
            // Appended data may omit nullable fields and reorder fields;
            // dictionary comparison only matters for the legacy format.
            let schema_cmp_opts = SchemaCompareOptions {
                compare_dictionary: dataset.manifest.should_use_legacy_format(),
                compare_nullability: NullabilityComparison::Ignore,
                allow_missing_if_nullable: true,
                ignore_field_order: true,
                ..Default::default()
            };
            data_schema.check_compatible(dataset.schema(), &schema_cmp_opts)?;
        }
        // Reject columns whose names collide with Lance's reserved
        // row-id/row-address/row-offset columns.
        for field in data_schema.fields.iter() {
            if field.name == ROW_ID || field.name == ROW_ADDR || field.name == ROW_OFFSET {
                return Err(Error::invalid_input_source(
                    format!(
                        "The column {} is a reserved name and cannot be used in a Lance dataset",
                        field.name
                    )
                    .into(),
                ));
            }
        }
        // Refuse to write if the dataset carries writer feature flags this
        // library version does not understand.
        if let WriteDestination::Dataset(dataset) = &context.dest
            && !can_write_dataset(dataset.manifest.writer_feature_flags)
        {
            let message = format!(
                "This dataset cannot be written by this version of Lance. \
                 Please upgrade Lance to write to this dataset.\n Flags: {}",
                dataset.manifest.writer_feature_flags
            );
            return Err(Error::not_supported_source(message.into()));
        }
        Ok(())
    }

    // Resolves the object store, base path, commit handler, destination, and
    // storage version for this write. When the destination is a URI, an
    // attempt is made to open the dataset there so later steps see it as an
    // existing dataset; a not-found error leaves the destination as a URI.
    async fn resolve_context(&self) -> Result<WriteContext<'a>> {
        let params = self.params.cloned().unwrap_or_default();
        let (object_store, base_path, commit_handler) = match &self.dest {
            // An open dataset already carries its store and commit handler.
            WriteDestination::Dataset(dataset) => (
                dataset.object_store.clone(),
                dataset.base.clone(),
                dataset.commit_handler.clone(),
            ),
            WriteDestination::Uri(uri) => {
                // Use the session's store registry when available, otherwise
                // fall back to a default registry.
                let registry = params
                    .session
                    .as_ref()
                    .map(|s| s.store_registry())
                    .unwrap_or_else(|| Arc::new(Default::default()));
                let (object_store, base_path) = ObjectStore::from_uri_and_params(
                    registry,
                    uri,
                    &params.store_params.clone().unwrap_or_default(),
                )
                .await?;
                let commit_handler = resolve_commit_handler(
                    uri,
                    params.commit_handler.clone(),
                    &params.store_params,
                )
                .await?;
                (object_store, base_path, commit_handler)
            }
        };
        let dest = match &self.dest {
            WriteDestination::Dataset(dataset) => WriteDestination::Dataset(dataset.clone()),
            WriteDestination::Uri(uri) => {
                // Probe for an existing dataset at the URI; only not-found
                // errors are tolerated, anything else aborts the write.
                let builder = DatasetBuilder::from_uri(uri).with_read_params(ReadParams {
                    store_options: params.store_params.clone(),
                    commit_handler: params.commit_handler.clone(),
                    session: params.session.clone(),
                    ..Default::default()
                });
                match builder.load().await {
                    Ok(dataset) => WriteDestination::Dataset(Arc::new(dataset)),
                    Err(Error::DatasetNotFound { .. } | Error::NotFound { .. }) => {
                        WriteDestination::Uri(uri)
                    }
                    Err(e) => return Err(e),
                }
            }
        };
        // Pick the data file format version: Overwrite may switch to a
        // user-requested version; other writes to an existing dataset keep
        // its current version; brand-new datasets use the requested or
        // default version.
        let storage_version = match (&params.mode, &dest) {
            (WriteMode::Overwrite, WriteDestination::Dataset(dataset)) => {
                params.data_storage_version.map(Ok).unwrap_or_else(|| {
                    let m = dataset.manifest.as_ref();
                    m.data_storage_format.lance_file_version()
                })?
            }
            (_, WriteDestination::Dataset(dataset)) => {
                let m = dataset.manifest.as_ref();
                m.data_storage_format.lance_file_version()?
            }
            (_, WriteDestination::Uri(_)) => params.storage_version_or_default(),
        };
        Ok(WriteContext {
            params,
            dest,
            object_store,
            base_path,
            commit_handler,
            storage_version,
        })
    }
}
/// Resolved state shared between the write phase and the commit phase of a
/// single insert operation.
#[derive(Debug)]
struct WriteContext<'a> {
    // Effective write parameters (user-provided or defaults); may be
    // adjusted during validation (e.g. mode, stable row ids).
    params: WriteParams,
    // Destination, upgraded to `Dataset` when the target URI already holds
    // a dataset.
    dest: WriteDestination<'a>,
    // Object store the fragments are written to.
    object_store: Arc<ObjectStore>,
    // Root path of the dataset within the object store.
    base_path: Path,
    // Handler used to commit the manifest.
    commit_handler: Arc<dyn CommitHandler>,
    // File format version new data files are written with.
    storage_version: LanceFileVersion,
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use arrow_array::{BinaryArray, Int32Array, RecordBatchReader, StructArray};
    use arrow_schema::{ArrowError, DataType, Field, Schema};
    use lance_arrow::BLOB_META_KEY;

    use crate::session::Session;

    use super::*;

    /// A session passed through `WriteParams` must be the exact session the
    /// created dataset uses (pointer equality, not just equivalence).
    #[tokio::test]
    async fn test_pass_session() {
        let session = Arc::new(Session::new(0, 0, Default::default()));
        let dataset = InsertBuilder::new("memory://")
            .with_params(&WriteParams {
                session: Some(session.clone()),
                ..Default::default()
            })
            .execute_stream(RecordBatchIterator::new(
                vec![],
                Arc::new(Schema::new(vec![Field::new("col", DataType::Int32, false)])),
            ))
            .await
            .unwrap();
        assert_eq!(Arc::as_ptr(&dataset.session()), Arc::as_ptr(&session));
    }

    /// A struct column with zero child fields can be written and the row is
    /// still counted as non-null.
    #[tokio::test]
    async fn test_write_empty_struct() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "empties",
            DataType::Struct(Vec::<Field>::new().into()),
            false,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(StructArray::new_empty_fields(1, None))],
        )
        .unwrap();
        let dataset = InsertBuilder::new("memory://")
            .execute_stream(RecordBatchIterator::new(vec![Ok(batch)], schema.clone()))
            .await
            .unwrap();
        assert_eq!(
            dataset
                .count_rows(Some("empties IS NOT NULL".to_string()))
                .await
                .unwrap(),
            1
        );
    }

    /// Overwriting an existing (blob-free) dataset while requesting storage
    /// version 2.2 must succeed.
    #[tokio::test]
    async fn allow_overwrite_to_v2_2_without_blob_upgrade() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1]))])
            .unwrap();
        // First create the dataset with default storage settings.
        let dataset = InsertBuilder::new("memory://blob-version-guard")
            .execute_stream(RecordBatchIterator::new(
                vec![Ok(batch.clone())],
                schema.clone(),
            ))
            .await
            .unwrap();
        let dataset = Arc::new(dataset);
        // Then overwrite it requesting V2_2 explicitly.
        let params = WriteParams {
            mode: WriteMode::Overwrite,
            data_storage_version: Some(LanceFileVersion::V2_2),
            ..Default::default()
        };
        let result = InsertBuilder::new(dataset.clone())
            .with_params(&params)
            .execute_stream(RecordBatchIterator::new(vec![Ok(batch)], schema.clone()))
            .await;
        assert!(result.is_ok());
    }

    /// Creating a V2_2 dataset with a legacy blob column (marked via
    /// `BLOB_META_KEY` metadata) must be rejected with an invalid-input
    /// error mentioning the migration hint.
    #[tokio::test]
    async fn create_v2_2_dataset_rejects_legacy_blob_schema() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("blob", DataType::Binary, false).with_metadata(HashMap::from([(
                BLOB_META_KEY.to_string(),
                "true".to_string(),
            )])),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(BinaryArray::from(vec![Some(b"abc".as_slice())]))],
        )
        .unwrap();
        let dataset = InsertBuilder::new("memory://forced-blob-v2")
            .with_params(&WriteParams {
                mode: WriteMode::Create,
                data_storage_version: Some(LanceFileVersion::V2_2),
                ..Default::default()
            })
            .execute_stream(RecordBatchIterator::new(vec![Ok(batch)], schema.clone()))
            .await;
        let err = dataset.unwrap_err();
        match err {
            Error::InvalidInput { source, .. } => {
                let message = source.to_string();
                assert!(message.contains("Legacy blob columns"));
                assert!(message.contains("lance.blob.v2"));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Tests that errors raised by a user-supplied record batch source are
    /// surfaced as `Error::External` with the original error intact
    /// (downcastable), rather than being flattened to a string.
    mod external_error {
        use super::*;
        use std::fmt;

        // Custom error type used to verify downcasting survives the
        // write pipeline.
        #[derive(Debug)]
        struct MyTestError {
            code: i32,
            details: String,
        }

        impl fmt::Display for MyTestError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "MyTestError({}): {}", self.code, self.details)
            }
        }

        impl std::error::Error for MyTestError {}

        // Produces up to 5 batches, injecting a `MyTestError` at the
        // 1-based position `fail_at_batch`.
        fn create_failing_iterator(
            schema: Arc<Schema>,
            fail_at_batch: usize,
            error_code: i32,
        ) -> impl Iterator<Item = std::result::Result<RecordBatch, ArrowError>> {
            let mut batch_count = 0;
            std::iter::from_fn(move || {
                if batch_count >= 5 {
                    return None;
                }
                batch_count += 1;
                if batch_count == fail_at_batch {
                    Some(Err(ArrowError::ExternalError(Box::new(MyTestError {
                        code: error_code,
                        details: format!("Failed at batch {}", batch_count),
                    }))))
                } else {
                    let batch = RecordBatch::try_new(
                        schema.clone(),
                        vec![Arc::new(Int32Array::from(vec![batch_count as i32; 10]))],
                    )
                    .unwrap();
                    Some(Ok(batch))
                }
            })
        }

        /// An error raised mid-stream (after some successful batches) is
        /// still downcastable to the original type.
        #[tokio::test]
        async fn test_insert_builder_preserves_external_error() {
            let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
            let error_code = 42;
            let iter = create_failing_iterator(schema.clone(), 3, error_code);
            let reader = RecordBatchIterator::new(iter, schema);
            let result = InsertBuilder::new("memory://test_external_error")
                .execute_stream(Box::new(reader) as Box<dyn RecordBatchReader + Send>)
                .await;
            match result {
                Err(Error::External { source }) => {
                    let original = source
                        .downcast_ref::<MyTestError>()
                        .expect("Should be able to downcast to MyTestError");
                    assert_eq!(original.code, error_code);
                    assert!(original.details.contains("batch 3"));
                }
                Err(other) => panic!("Expected Error::External variant, got: {:?}", other),
                Ok(_) => panic!("Expected error, got success"),
            }
        }

        /// An error on the very first batch is handled the same way as a
        /// mid-stream failure.
        #[tokio::test]
        async fn test_insert_builder_first_batch_error() {
            let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
            let error_code = 999;
            let iter = std::iter::once(Err(ArrowError::ExternalError(Box::new(MyTestError {
                code: error_code,
                details: "immediate failure".to_string(),
            }))));
            let reader = RecordBatchIterator::new(iter, schema);
            let result = InsertBuilder::new("memory://test_first_batch_error")
                .execute_stream(Box::new(reader) as Box<dyn RecordBatchReader + Send>)
                .await;
            match result {
                Err(Error::External { source }) => {
                    let original = source.downcast_ref::<MyTestError>().unwrap();
                    assert_eq!(original.code, error_code);
                }
                Err(other) => panic!("Expected External, got: {:?}", other),
                Ok(_) => panic!("Expected error"),
            }
        }
    }
}