use arrow_array::RecordBatch;
use chrono::TimeDelta;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::{Stream, StreamExt, TryStreamExt};
use lance_arrow::BLOB_META_KEY;
use lance_core::datatypes::{
NullabilityComparison, OnMissing, OnTypeMismatch, SchemaCompareOptions,
};
use lance_core::error::LanceOptionExt;
use lance_core::utils::tempfile::TempDir;
use lance_core::utils::tracing::{AUDIT_MODE_CREATE, AUDIT_TYPE_DATA, TRACE_FILE_AUDIT};
use lance_core::{Error, Result, datatypes::Schema};
use lance_datafusion::chunker::{break_stream, chunk_stream};
use lance_datafusion::spill::{SpillReceiver, SpillSender, create_replay_spill};
use lance_datafusion::utils::StreamingWriteSource;
use lance_file::previous::writer::{
FileWriter as PreviousFileWriter, ManifestProvider as PreviousManifestProvider,
};
use lance_file::version::LanceFileVersion;
use lance_file::writer::{self as current_writer, FileWriterOptions};
use lance_io::object_store::{ObjectStore, ObjectStoreParams, ObjectStoreRegistry};
use lance_table::format::{BasePath, DataFile, Fragment};
use lance_table::io::commit::{CommitHandler, commit_handler_from_url};
use lance_table::io::manifest::ManifestDescribing;
use object_store::path::Path;
use std::collections::{HashMap, HashSet};
use std::num::NonZero;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use tracing::{info, instrument};
use crate::Dataset;
use crate::dataset::blob::{
BlobPreprocessor, ExternalBaseCandidate, ExternalBaseResolver, preprocess_blob_batches,
};
use crate::session::Session;
use super::DATA_DIR;
use super::fragment::write::generate_random_filename;
use super::progress::{NoopFragmentWriteProgress, WriteFragmentProgress};
use super::transaction::Transaction;
use super::utils::SchemaAdapter;
mod commit;
pub mod delete;
mod insert;
pub mod merge_insert;
mod retry;
pub mod update;
pub use commit::CommitBuilder;
pub use delete::{DeleteBuilder, DeleteResult};
pub use insert::InsertBuilder;
/// The target of a write: either a shared handle to a [`Dataset`] or a URI
/// string borrowed for the lifetime `'a`.
#[derive(Debug, Clone)]
pub enum WriteDestination<'a> {
/// A shared handle to a dataset.
Dataset(Arc<Dataset>),
/// A borrowed URI string identifying the destination.
Uri(&'a str),
}
impl WriteDestination<'_> {
pub fn dataset(&self) -> Option<&Dataset> {
match self {
WriteDestination::Dataset(dataset) => Some(dataset.as_ref()),
WriteDestination::Uri(_) => None,
}
}
pub fn uri(&self) -> String {
match self {
WriteDestination::Dataset(dataset) => dataset.uri.clone(),
WriteDestination::Uri(uri) => uri.to_string(),
}
}
}
impl From<Arc<Dataset>> for WriteDestination<'_> {
fn from(dataset: Arc<Dataset>) -> Self {
WriteDestination::Dataset(dataset)
}
}
impl<'a> From<&'a str> for WriteDestination<'a> {
fn from(uri: &'a str) -> Self {
WriteDestination::Uri(uri)
}
}
impl<'a> From<&'a String> for WriteDestination<'a> {
fn from(uri: &'a String) -> Self {
WriteDestination::Uri(uri.as_str())
}
}
impl<'a> From<&'a Path> for WriteDestination<'a> {
fn from(path: &'a Path) -> Self {
WriteDestination::Uri(path.as_ref())
}
}
/// How a write should treat the destination.
///
/// Each variant can be parsed from its case-insensitive name via
/// `TryFrom<&str>`. Registering new base paths is only permitted in
/// [`WriteMode::Create`].
#[derive(Debug, Clone, Copy)]
pub enum WriteMode {
/// Parsed from `"create"`; presumably creates a new dataset — confirm exact semantics with the commit path.
Create,
/// Parsed from `"append"`; presumably adds data to an existing dataset — confirm.
Append,
/// Parsed from `"overwrite"`; presumably replaces the dataset contents — confirm.
Overwrite,
}
impl TryFrom<&str> for WriteMode {
type Error = Error;
fn try_from(value: &str) -> Result<Self> {
match value.to_lowercase().as_str() {
"create" => Ok(Self::Create),
"append" => Ok(Self::Append),
"overwrite" => Ok(Self::Overwrite),
_ => Err(Error::invalid_input(format!(
"Invalid write mode: {}",
value
))),
}
}
}
/// Settings for automatic cleanup. The defaults are an interval of 20 and an
/// age threshold of 14 days.
#[derive(Debug, Clone)]
pub struct AutoCleanupParams {
/// Cleanup interval; presumably counted in commits — TODO confirm the unit against the cleanup code.
pub interval: usize,
/// Age threshold; presumably only versions older than this are cleaned up — TODO confirm.
pub older_than: TimeDelta,
}
impl Default for AutoCleanupParams {
    /// Defaults to an interval of 20 and an age threshold of 14 days.
    fn default() -> Self {
        let interval = 20;
        let older_than = TimeDelta::days(14);
        Self {
            interval,
            older_than,
        }
    }
}
/// Options controlling how data is written into fragments and data files.
#[derive(Debug, Clone)]
pub struct WriteParams {
/// Row count at which the current data file is closed and a new one is started.
pub max_rows_per_file: usize,
/// Chunk size used when batching rows for the legacy file format. It is
/// clamped to `max_rows_per_file` before writing.
pub max_rows_per_group: usize,
/// Writer position (as reported by the writer's `tell`) at which the current
/// data file is closed.
pub max_bytes_per_file: usize,
/// How the write treats the destination.
pub mode: WriteMode,
/// Object store parameters used when opening stores for base paths;
/// the default parameters are used when `None`.
pub store_params: Option<ObjectStoreParams>,
/// Progress callbacks: `begin` is invoked when a fragment's writer is opened
/// and `complete` when a fragment's file is finished.
pub progress: Arc<dyn WriteFragmentProgress>,
/// Optional custom commit handler.
pub commit_handler: Option<Arc<dyn CommitHandler>>,
/// File format version to write. When `None` the version's `Default` is used
/// for new datasets; for an existing dataset the manifest's version may take
/// precedence depending on `mode`.
pub data_storage_version: Option<LanceFileVersion>,
/// NOTE(review): not consumed in this part of the file; presumably enables stable row ids — confirm.
pub enable_stable_row_ids: bool,
/// NOTE(review): not consumed in this part of the file; presumably selects the v2 manifest path scheme — confirm.
pub enable_v2_manifest_paths: bool,
/// Optional session; when present it supplies the object store registry.
pub session: Option<Arc<Session>>,
/// Automatic cleanup settings; `Default` enables them with `AutoCleanupParams::default()`.
pub auto_cleanup: Option<AutoCleanupParams>,
/// NOTE(review): not consumed in this part of the file; presumably disables auto cleanup for this write — confirm.
pub skip_auto_cleanup: bool,
/// Optional key/value properties; set via `with_transaction_properties`.
pub transaction_properties: Option<Arc<HashMap<String, String>>>,
/// New base paths to register. Only accepted in `WriteMode::Create`; entries
/// with `id == 0` get an id assigned during validation.
pub initial_bases: Option<Vec<BasePath>>,
/// IDs of the bases to write data files to. Mutually exclusive with
/// `target_base_names_or_paths`.
pub target_bases: Option<Vec<u32>>,
/// Names or paths of the bases to write data files to; each entry is matched
/// against a base's name or path. Mutually exclusive with `target_bases`.
pub target_base_names_or_paths: Option<Vec<String>>,
/// Forwarded to the blob preprocessor; presumably permits external blobs that
/// live outside any registered base — confirm against `BlobPreprocessor`.
pub allow_external_blob_outside_bases: bool,
}
impl Default for WriteParams {
    /// Builds the default write parameters: `Create` mode, 1024 * 1024 rows
    /// per file, 1024 rows per group, a 90 GiB byte limit per file, a no-op
    /// progress tracker, v2 manifest paths enabled and auto cleanup enabled
    /// with its own defaults. Every other option is unset or `false`.
    fn default() -> Self {
        const KIB: usize = 1024;

        Self {
            // Sizing limits.
            max_rows_per_file: KIB * KIB,
            max_rows_per_group: KIB,
            max_bytes_per_file: 90 * KIB * KIB * KIB,
            // Behavior.
            mode: WriteMode::Create,
            data_storage_version: None,
            enable_stable_row_ids: false,
            enable_v2_manifest_paths: true,
            auto_cleanup: Some(AutoCleanupParams::default()),
            skip_auto_cleanup: false,
            // Collaborators.
            store_params: None,
            progress: Arc::new(NoopFragmentWriteProgress::new()),
            commit_handler: None,
            session: None,
            transaction_properties: None,
            // Multi-base configuration.
            initial_bases: None,
            target_bases: None,
            target_base_names_or_paths: None,
            allow_external_blob_outside_bases: false,
        }
    }
}
impl WriteParams {
pub fn with_storage_version(version: LanceFileVersion) -> Self {
Self {
data_storage_version: Some(version),
..Default::default()
}
}
pub fn storage_version_or_default(&self) -> LanceFileVersion {
self.data_storage_version.unwrap_or_default()
}
pub fn store_registry(&self) -> Arc<ObjectStoreRegistry> {
self.session
.as_ref()
.map(|s| s.store_registry())
.unwrap_or_default()
}
pub fn with_transaction_properties(self, properties: HashMap<String, String>) -> Self {
Self {
transaction_properties: Some(Arc::new(properties)),
..self
}
}
pub fn with_initial_bases(self, bases: Vec<BasePath>) -> Self {
Self {
initial_bases: Some(bases),
..self
}
}
pub fn with_target_bases(self, base_ids: Vec<u32>) -> Self {
Self {
target_bases: Some(base_ids),
..self
}
}
pub fn with_target_base_names_or_paths(self, references: Vec<String>) -> Self {
Self {
target_base_names_or_paths: Some(references),
..self
}
}
pub fn with_allow_external_blob_outside_bases(self, allow: bool) -> Self {
Self {
allow_external_blob_outside_bases: allow,
..self
}
}
}
/// Writes `data` to `dest` as uncommitted fragments and returns the resulting
/// [`Transaction`].
///
/// This is a thin wrapper that builds an [`InsertBuilder`] for `dest`,
/// applies `params`, and calls `execute_uncommitted_stream`.
///
/// # Errors
///
/// Propagates any error returned by the underlying insert operation.
#[deprecated(
    since = "0.20.0",
    note = "Use [`InsertBuilder::execute_uncommitted_stream`] instead"
)]
pub async fn write_fragments(
    dest: impl Into<WriteDestination<'_>>,
    data: impl StreamingWriteSource,
    params: WriteParams,
) -> Result<Transaction> {
    InsertBuilder::new(dest.into())
        .with_params(&params)
        .execute_uncommitted_stream(data)
        .await
}
/// Streams `data` into one or more data files and returns one [`Fragment`]
/// per file written.
///
/// The input stream is first converted to its physical representation. For
/// the legacy file format it is chunked by `params.max_rows_per_group`;
/// otherwise it is broken at `params.max_rows_per_file` boundaries. A new
/// writer (and fragment) is opened lazily when the first chunk for a file
/// arrives. The current file is closed as soon as it holds at least
/// `params.max_rows_per_file` rows or the writer position reaches
/// `params.max_bytes_per_file`.
///
/// `params.progress.begin` is called when a fragment's writer is opened and
/// `params.progress.complete` when its file has been finished, including the
/// trailing, partially filled file.
///
/// When the storage version is at least 2.2 and the schema has a blob v2
/// field, an external base resolver is built from `dataset` and `params` and
/// handed to each writer.
///
/// # Errors
///
/// Returns the first error produced by the input stream, the progress
/// callbacks, the external base resolver, or the file writers.
#[allow(clippy::too_many_arguments)]
pub async fn do_write_fragments(
    dataset: Option<&Dataset>,
    object_store: Arc<ObjectStore>,
    base_dir: &Path,
    schema: &Schema,
    data: SendableRecordBatchStream,
    params: WriteParams,
    storage_version: LanceFileVersion,
    target_bases_info: Option<Vec<TargetBaseInfo>>,
) -> Result<Vec<Fragment>> {
    let adapter = SchemaAdapter::new(data.schema());
    let data = adapter.to_physical_stream(data);

    let mut buffered_reader = if storage_version == LanceFileVersion::Legacy {
        chunk_stream(data, params.max_rows_per_group)
    } else {
        break_stream(data, params.max_rows_per_file)
            .map_ok(|batch| vec![batch])
            .boxed()
    };

    let external_base_resolver = if storage_version >= LanceFileVersion::V2_2
        && schema.fields.iter().any(|field| field.is_blob_v2())
    {
        Some(Arc::new(
            build_external_base_resolver(dataset, &params).await?,
        ))
    } else {
        None
    };

    let writer_generator = WriterGenerator::new(
        object_store,
        base_dir,
        schema,
        storage_version,
        target_bases_info,
        external_base_resolver,
        params.allow_external_blob_outside_bases,
    );

    let mut writer: Option<Box<dyn GenericWriter>> = None;
    // Counted in `usize` so that neither the counter nor `max_rows_per_file`
    // is truncated by a narrowing cast to `u32`.
    let mut num_rows_in_current_file: usize = 0;
    let mut fragments = Vec::new();

    while let Some(batch_chunk) = buffered_reader.next().await {
        let batch_chunk = batch_chunk?;

        if writer.is_none() {
            let (new_writer, new_fragment) = writer_generator.new_writer().await?;
            params.progress.begin(&new_fragment).await?;
            writer = Some(new_writer);
            fragments.push(new_fragment);
        }

        let current = writer
            .as_mut()
            .expect("a writer is always open at this point");
        current.write(&batch_chunk).await?;
        num_rows_in_current_file += batch_chunk
            .iter()
            .map(|batch| batch.num_rows())
            .sum::<usize>();

        if num_rows_in_current_file >= params.max_rows_per_file
            || current.tell().await? >= params.max_bytes_per_file as u64
        {
            let mut finished = writer
                .take()
                .expect("a writer is always open at this point");
            let (num_rows, data_file) = finished.finish().await?;
            info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, r#type=AUDIT_TYPE_DATA, path = &data_file.path);
            debug_assert_eq!(num_rows as usize, num_rows_in_current_file);
            params.progress.complete(fragments.last().unwrap()).await?;
            let last_fragment = fragments.last_mut().unwrap();
            last_fragment.physical_rows = Some(num_rows as usize);
            last_fragment.files.push(data_file);
            num_rows_in_current_file = 0;
        }
    }

    // Flush the trailing, partially filled file (if any).
    if let Some(mut writer) = writer.take() {
        let (num_rows, data_file) = writer.finish().await?;
        info!(target: TRACE_FILE_AUDIT, mode=AUDIT_MODE_CREATE, r#type=AUDIT_TYPE_DATA, path = &data_file.path);
        // Every `begin` gets a matching `complete`, including the last fragment.
        params.progress.complete(fragments.last().unwrap()).await?;
        let last_fragment = fragments.last_mut().unwrap();
        last_fragment.physical_rows = Some(num_rows as usize);
        last_fragment.files.push(data_file);
    }

    Ok(fragments)
}
pub async fn validate_and_resolve_target_bases(
params: &mut WriteParams,
existing_base_paths: Option<&HashMap<u32, BasePath>>,
) -> Result<Option<Vec<TargetBaseInfo>>> {
if !matches!(params.mode, WriteMode::Create) && params.initial_bases.is_some() {
return Err(Error::invalid_input(format!(
"Cannot register new bases in {:?} mode. Only CREATE mode can register new bases.",
params.mode
)));
}
if params.target_base_names_or_paths.is_some() && params.target_bases.is_some() {
return Err(Error::invalid_input(
"Cannot specify both target_base_names_or_paths and target_bases. Use one or the other.",
));
}
let mut all_bases: HashMap<u32, BasePath> = existing_base_paths.cloned().unwrap_or_default();
if let Some(initial_bases) = &mut params.initial_bases {
let mut next_id = all_bases.keys().max().map(|&id| id + 1).unwrap_or(1);
for base_path in initial_bases.iter_mut() {
if base_path.id == 0 {
base_path.id = next_id;
next_id += 1;
}
all_bases.insert(base_path.id, base_path.clone());
}
}
let target_base_ids = if let Some(ref names_or_paths) = params.target_base_names_or_paths {
let mut resolved_ids = Vec::new();
for reference in names_or_paths {
let ref_str = reference.as_str();
let id = all_bases
.iter()
.find(|(_, base)| {
base.name.as_ref().map(|n| n == ref_str).unwrap_or(false)
|| base.path == ref_str
})
.map(|(&id, _)| id)
.ok_or_else(|| {
Error::invalid_input(format!(
"Base reference '{}' not found in available bases",
ref_str
))
})?;
resolved_ids.push(id);
}
Some(resolved_ids)
} else {
params.target_bases.clone()
};
let store_registry = params
.session
.as_ref()
.map(|s| s.store_registry())
.unwrap_or_default();
if let Some(target_bases) = &target_base_ids {
let store_params = params.store_params.clone().unwrap_or_default();
let mut bases_info = Vec::new();
for &target_base_id in target_bases {
let base_path = all_bases.get(&target_base_id).ok_or_else(|| {
Error::invalid_input(format!(
"Target base ID {} not found in available bases",
target_base_id
))
})?;
let (target_object_store, extracted_path) = ObjectStore::from_uri_and_params(
store_registry.clone(),
&base_path.path,
&store_params,
)
.await?;
bases_info.push(TargetBaseInfo {
base_id: target_base_id,
object_store: target_object_store,
base_dir: extracted_path,
is_dataset_root: base_path.is_dataset_root,
});
}
Ok(Some(bases_info))
} else {
Ok(None)
}
}
/// Adds `base_path` to `candidates` unless it is the dataset root or its id
/// has already been recorded in `seen_base_ids`.
///
/// The id is only inserted into `seen_base_ids` for non-root bases.
fn append_external_base_candidate(
    base_path: &BasePath,
    store_prefix: String,
    extracted_path: Path,
    candidates: &mut Vec<ExternalBaseCandidate>,
    seen_base_ids: &mut HashSet<u32>,
) {
    // Short-circuiting keeps dataset-root ids out of `seen_base_ids`.
    let first_sighting = !base_path.is_dataset_root && seen_base_ids.insert(base_path.id);
    if !first_sighting {
        return;
    }
    let candidate = ExternalBaseCandidate {
        base_id: base_path.id,
        store_prefix,
        base_path: extracted_path,
    };
    candidates.push(candidate);
}
/// Opens an object store for every entry of `initial_bases` and records each
/// one as an external base candidate via `append_external_base_candidate`.
///
/// Does nothing when `initial_bases` is `None`.
///
/// # Errors
///
/// Propagates the first failure to open an object store.
async fn append_external_initial_bases(
    initial_bases: Option<&Vec<BasePath>>,
    store_registry: Arc<ObjectStoreRegistry>,
    store_params: &ObjectStoreParams,
    candidates: &mut Vec<ExternalBaseCandidate>,
    seen_base_ids: &mut HashSet<u32>,
) -> Result<()> {
    let Some(bases) = initial_bases else {
        return Ok(());
    };
    for base in bases {
        let registry = store_registry.clone();
        let (store, extracted_path) =
            ObjectStore::from_uri_and_params(registry, &base.path, store_params).await?;
        let prefix = store.store_prefix.clone();
        append_external_base_candidate(base, prefix, extracted_path, candidates, seen_base_ids);
    }
    Ok(())
}
/// Builds an [`ExternalBaseResolver`] from the base paths of `dataset` (when
/// given) followed by `params.initial_bases`.
///
/// The store registry comes from the dataset's session when a dataset is
/// present, otherwise from `params`. Dataset-root bases and repeated base
/// ids are skipped.
///
/// # Errors
///
/// Propagates the first failure to open an object store for a base path.
async fn build_external_base_resolver(
    dataset: Option<&Dataset>,
    params: &WriteParams,
) -> Result<ExternalBaseResolver> {
    let store_registry = match dataset {
        Some(ds) => ds.session.store_registry(),
        None => params.store_registry(),
    };
    let store_params = params.store_params.clone().unwrap_or_default();

    let mut candidates = Vec::new();
    let mut seen_base_ids = HashSet::new();

    // Bases already registered on the dataset come first.
    if let Some(ds) = dataset {
        for base in ds.manifest.base_paths.values() {
            let (store, extracted_path) = ObjectStore::from_uri_and_params(
                store_registry.clone(),
                &base.path,
                &store_params,
            )
            .await?;
            let prefix = store.store_prefix.clone();
            append_external_base_candidate(
                base,
                prefix,
                extracted_path,
                &mut candidates,
                &mut seen_base_ids,
            );
        }
    }

    // Then the bases being registered by this write.
    append_external_initial_bases(
        params.initial_bases.as_ref(),
        store_registry.clone(),
        &store_params,
        &mut candidates,
        &mut seen_base_ids,
    )
    .await?;

    let resolver = ExternalBaseResolver::new(candidates, store_registry, store_params);
    Ok(resolver)
}
/// Validates the write against the destination and writes `data` as new
/// fragments, returning the fragments together with the schema that was
/// actually written.
///
/// * If the stream needs a physical conversion, the schema is re-derived from
///   the converted stream; otherwise the supplied `schema` is used.
/// * `max_rows_per_group` is clamped to `max_rows_per_file`.
/// * With an existing `dataset` in `Append` or `Create` mode, the data schema
///   must be compatible with the dataset schema; the dataset schema projected
///   onto the data schema is written using the dataset's storage version.
/// * With an existing `dataset` in `Overwrite` mode, the explicit
///   `params.data_storage_version` wins; the dataset's version is consulted
///   only when none was given.
/// * Without a dataset, the configured (or default) storage version is used.
///
/// # Errors
///
/// Returns an invalid-input error when blob v2 fields are used with a file
/// version below 2.2, or when legacy blob columns are used with a file
/// version of 2.2 or above. Also propagates schema compatibility, projection,
/// version parsing and write errors.
#[instrument(level = "debug", skip_all)]
pub async fn write_fragments_internal(
    dataset: Option<&Dataset>,
    object_store: Arc<ObjectStore>,
    base_dir: &Path,
    schema: Schema,
    data: SendableRecordBatchStream,
    params: WriteParams,
    target_bases_info: Option<Vec<TargetBaseInfo>>,
) -> Result<(Vec<Fragment>, Schema)> {
    let mut params = params;
    let adapter = SchemaAdapter::new(data.schema());
    let (data, converted_schema) = if adapter.requires_physical_conversion() {
        let data = adapter.to_physical_stream(data);
        let arrow_schema = data.schema();
        let converted_schema = Schema::try_from(arrow_schema.as_ref())?;
        (data, converted_schema)
    } else {
        (data, schema)
    };

    params.max_rows_per_group = std::cmp::min(params.max_rows_per_group, params.max_rows_per_file);

    let (schema, storage_version) = if let Some(dataset) = dataset {
        match params.mode {
            WriteMode::Append | WriteMode::Create => {
                converted_schema.check_compatible(
                    dataset.schema(),
                    &SchemaCompareOptions {
                        compare_nullability: NullabilityComparison::Ignore,
                        allow_missing_if_nullable: true,
                        ignore_field_order: true,
                        compare_dictionary: dataset.is_legacy_storage(),
                        ..Default::default()
                    },
                )?;
                let write_schema = dataset.schema().project_by_schema(
                    &converted_schema,
                    OnMissing::Error,
                    OnTypeMismatch::Error,
                )?;
                let data_storage_version = dataset
                    .manifest()
                    .data_storage_format
                    .lance_file_version()?;
                (write_schema, data_storage_version)
            }
            WriteMode::Overwrite => {
                // Look up the dataset's version lazily: an explicit version
                // must not fail because the old manifest's version cannot be
                // resolved.
                let data_storage_version = match params.data_storage_version {
                    Some(version) => version,
                    None => dataset
                        .manifest()
                        .data_storage_format
                        .lance_file_version()?,
                };
                (converted_schema, data_storage_version)
            }
        }
    } else {
        (converted_schema, params.storage_version_or_default())
    };

    if storage_version < LanceFileVersion::V2_2 && schema.fields.iter().any(|f| f.is_blob_v2()) {
        return Err(Error::invalid_input(format!(
            "Blob v2 requires file version >= 2.2 (got {:?})",
            storage_version
        )));
    }
    if storage_version >= LanceFileVersion::V2_2
        && schema
            .fields
            .iter()
            .any(|f| f.metadata.contains_key(BLOB_META_KEY))
    {
        return Err(Error::invalid_input(format!(
            "Legacy blob columns (field metadata key {BLOB_META_KEY:?}) are not supported for file version >= 2.2. Use the blob v2 extension type (ARROW:extension:name = \"lance.blob.v2\") and the new blob APIs (e.g. lance::blob::blob_field / lance::blob::BlobArrayBuilder)."
        )));
    }

    let fragments = do_write_fragments(
        dataset,
        object_store,
        base_dir,
        &schema,
        data,
        params,
        storage_version,
        target_bases_info,
    )
    .await?;
    Ok((fragments, schema))
}
/// A data file writer abstraction implemented by the adapters for the legacy
/// and current file writers.
#[async_trait::async_trait]
pub trait GenericWriter: Send {
/// Writes the given batches to the file.
async fn write(&mut self, batches: &[RecordBatch]) -> Result<()>;
/// Returns the writer's current position; callers compare it against a
/// byte limit, so it is presumably the number of bytes written so far — confirm.
async fn tell(&mut self) -> Result<u64>;
/// Finishes the file and returns the number of rows written together with
/// the [`DataFile`] describing it.
async fn finish(&mut self) -> Result<(u32, DataFile)>;
}
/// Adapts the legacy file writer to the [`GenericWriter`] trait.
struct V1WriterAdapter<M>
where
M: PreviousManifestProvider + Send + Sync,
{
// The wrapped legacy file writer.
writer: PreviousFileWriter<M>,
// Path recorded in the resulting `DataFile`.
path: String,
// Optional base id recorded in the resulting `DataFile`.
base_id: Option<u32>,
}
#[async_trait::async_trait]
impl<M> GenericWriter for V1WriterAdapter<M>
where
    M: PreviousManifestProvider + Send + Sync,
{
    /// Forwards the batches to the legacy writer.
    async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> {
        self.writer.write(batches).await
    }

    /// Returns the legacy writer's current position as `u64`.
    async fn tell(&mut self) -> Result<u64> {
        let position = self.writer.tell().await?;
        Ok(position as u64)
    }

    /// Finishes the legacy file and returns its row count and a legacy
    /// [`DataFile`] record.
    ///
    /// The recorded size is the writer position sampled before `finish` is
    /// called.
    async fn finish(&mut self) -> Result<(u32, DataFile)> {
        let size_bytes = self.writer.tell().await?;
        let num_rows = self.writer.finish().await? as u32;
        let data_file = DataFile::new_legacy(
            self.path.clone(),
            self.writer.schema(),
            NonZero::new(size_bytes as u64),
            self.base_id,
        );
        Ok((num_rows, data_file))
    }
}
/// Adapts the current file writer to the [`GenericWriter`] trait, optionally
/// running batches through a blob preprocessor first.
struct V2WriterAdapter {
// The wrapped file writer.
writer: current_writer::FileWriter,
// Path recorded in the resulting `DataFile`; it is moved out by `finish`.
path: String,
// Optional base id recorded in the resulting `DataFile`.
base_id: Option<u32>,
// When set, batches are preprocessed before being written and the
// preprocessor is finished before the file is.
preprocessor: Option<BlobPreprocessor>,
}
#[async_trait::async_trait]
impl GenericWriter for V2WriterAdapter {
    /// Writes the batches one by one, running them through the blob
    /// preprocessor first when one is configured.
    async fn write(&mut self, batches: &[RecordBatch]) -> Result<()> {
        match self.preprocessor.as_mut() {
            Some(pre) => {
                let processed = preprocess_blob_batches(batches, pre).await?;
                for batch in processed {
                    self.writer.write_batch(&batch).await?;
                }
            }
            None => {
                for batch in batches {
                    self.writer.write_batch(batch).await?;
                }
            }
        }
        Ok(())
    }

    /// Returns the underlying writer's current position.
    async fn tell(&mut self) -> Result<u64> {
        let position = self.writer.tell().await?;
        Ok(position)
    }

    /// Finishes the preprocessor (if any) and the file, then returns the row
    /// count and the [`DataFile`] record.
    ///
    /// The field-id/column-index mapping and the format version are read
    /// before the writer is finished; the recorded size is the writer
    /// position sampled after it. The stored path is moved into the record,
    /// leaving an empty string behind.
    async fn finish(&mut self) -> Result<(u32, DataFile)> {
        if let Some(pre) = self.preprocessor.as_mut() {
            pre.finish().await?;
        }

        let (field_ids, column_indices): (Vec<i32>, Vec<i32>) = self
            .writer
            .field_id_to_column_indices()
            .iter()
            .map(|(field_id, column_index)| (*field_id as i32, *column_index as i32))
            .unzip();
        let (major, minor) = self.writer.version().to_numbers();
        let num_rows = self.writer.finish().await? as u32;

        // Take the path before sampling the size, matching the order in
        // which the record's arguments are evaluated.
        let path = std::mem::take(&mut self.path);
        let file_size = NonZero::new(self.writer.tell().await?);
        let data_file = DataFile::new(
            path,
            field_ids,
            column_indices,
            major,
            minor,
            file_size,
            self.base_id,
        );
        Ok((num_rows, data_file))
    }
}
pub async fn open_writer(
object_store: &ObjectStore,
schema: &Schema,
base_dir: &Path,
storage_version: LanceFileVersion,
) -> Result<Box<dyn GenericWriter>> {
open_writer_with_options(
object_store,
schema,
base_dir,
storage_version,
WriterOptions {
add_data_dir: true,
..Default::default()
},
)
.await
}
/// Options for `open_writer_with_options`.
#[derive(Default)]
struct WriterOptions {
// When true the file is created in the `DATA_DIR` child of the base
// directory; otherwise directly in the base directory.
add_data_dir: bool,
// Optional base id recorded on the resulting data file.
base_id: Option<u32>,
// Optional resolver handed to the blob preprocessor.
external_base_resolver: Option<Arc<ExternalBaseResolver>>,
// Flag handed to the blob preprocessor.
allow_external_blob_outside_bases: bool,
}
/// Opens a writer for a new data file named `<random key>.lance`.
///
/// The file is placed in the `DATA_DIR` child of `base_dir` when
/// `options.add_data_dir` is set, otherwise directly in `base_dir`. The
/// legacy storage version gets a `V1WriterAdapter`; every other version gets
/// a `V2WriterAdapter`, which carries a blob preprocessor when the version is
/// at least 2.2. The adapter records only the file name, not the full path.
///
/// # Errors
///
/// Propagates failures to create the file or the file writer.
async fn open_writer_with_options(
    object_store: &ObjectStore,
    schema: &Schema,
    base_dir: &Path,
    storage_version: LanceFileVersion,
    options: WriterOptions,
) -> Result<Box<dyn GenericWriter>> {
    let data_file_key = generate_random_filename();
    let filename = format!("{}.lance", data_file_key);
    let data_dir = match options.add_data_dir {
        true => base_dir.child(DATA_DIR),
        false => base_dir.clone(),
    };
    let full_path = data_dir.child(filename.as_str());

    if storage_version == LanceFileVersion::Legacy {
        let legacy_writer = PreviousFileWriter::<ManifestDescribing>::try_new(
            object_store,
            &full_path,
            schema.clone(),
            &Default::default(),
        )
        .await?;
        let adapter = V1WriterAdapter {
            writer: legacy_writer,
            path: filename,
            base_id: options.base_id,
        };
        return Ok(Box::new(adapter) as Box<dyn GenericWriter>);
    }

    let sink = object_store.create(&full_path).await?;
    let enable_blob_v2 = storage_version >= LanceFileVersion::V2_2;
    let file_writer = current_writer::FileWriter::try_new(
        sink,
        schema.clone(),
        FileWriterOptions {
            format_version: Some(storage_version),
            ..Default::default()
        },
    )?;
    let external_base_resolver = options.external_base_resolver;
    let allow_outside_bases = options.allow_external_blob_outside_bases;
    let preprocessor = enable_blob_v2.then(|| {
        BlobPreprocessor::new(
            object_store.clone(),
            data_dir.clone(),
            data_file_key.clone(),
            schema,
            external_base_resolver,
            allow_outside_bases,
        )
    });
    let adapter = V2WriterAdapter {
        writer: file_writer,
        path: filename,
        base_id: options.base_id,
        preprocessor,
    };
    Ok(Box::new(adapter) as Box<dyn GenericWriter>)
}
/// A resolved target base: the object store and directory that data files
/// for this base are written to.
pub struct TargetBaseInfo {
/// ID of the base; recorded on data files written to it.
pub base_id: u32,
/// Object store opened for the base's path.
pub object_store: Arc<ObjectStore>,
/// Path within `object_store` extracted from the base's path.
pub base_dir: Path,
/// Copied from the base path entry; when true, data files are placed in the
/// `DATA_DIR` child of `base_dir`.
pub is_dataset_root: bool,
}
/// Creates a new writer and fragment for every data file, cycling through the
/// target bases when any are configured.
struct WriterGenerator {
// Default object store, used when no target bases are configured.
object_store: Arc<ObjectStore>,
// Default base directory, used together with `object_store`.
base_dir: Path,
// Schema handed to every writer.
schema: Schema,
// File format version handed to every writer.
storage_version: LanceFileVersion,
// Optional list of target bases that writers are distributed over.
target_bases_info: Option<Vec<TargetBaseInfo>>,
// Optional resolver handed to every writer.
external_base_resolver: Option<Arc<ExternalBaseResolver>>,
// Flag handed to every writer.
allow_external_blob_outside_bases: bool,
// Counter used to pick the next target base; incremented on every selection.
next_base_index: AtomicUsize,
}
impl WriterGenerator {
    /// Creates a generator that opens writers in `base_dir` of `object_store`,
    /// or in the given target bases when `target_bases_info` is non-empty.
    pub fn new(
        object_store: Arc<ObjectStore>,
        base_dir: &Path,
        schema: &Schema,
        storage_version: LanceFileVersion,
        target_bases_info: Option<Vec<TargetBaseInfo>>,
        external_base_resolver: Option<Arc<ExternalBaseResolver>>,
        allow_external_blob_outside_bases: bool,
    ) -> Self {
        Self {
            object_store,
            base_dir: base_dir.clone(),
            schema: schema.clone(),
            storage_version,
            target_bases_info,
            external_base_resolver,
            allow_external_blob_outside_bases,
            next_base_index: AtomicUsize::new(0),
        }
    }

    /// Picks the next target base in rotation.
    ///
    /// Returns `None` when no target bases are configured. An empty list is
    /// treated the same way, which avoids a division by zero in the modulo
    /// below.
    fn select_target_base(&self) -> Option<&TargetBaseInfo> {
        let bases = self
            .target_bases_info
            .as_deref()
            .filter(|bases| !bases.is_empty())?;
        let index = self
            .next_base_index
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        bases.get(index % bases.len())
    }

    /// Opens a writer for a new data file and returns it together with a new
    /// fragment whose id is `0`.
    ///
    /// With a target base selected, the file goes to that base's store and
    /// directory, is tagged with the base id, and is placed in the data
    /// directory only when the base is the dataset root. Otherwise it goes
    /// to the data directory of the default store, without a base id.
    ///
    /// # Errors
    ///
    /// Propagates failures to open the writer.
    pub async fn new_writer(&self) -> Result<(Box<dyn GenericWriter>, Fragment)> {
        let fragment = Fragment::new(0);

        let (store, dir, add_data_dir, base_id) = match self.select_target_base() {
            Some(base_info) => (
                &base_info.object_store,
                &base_info.base_dir,
                base_info.is_dataset_root,
                Some(base_info.base_id),
            ),
            None => (&self.object_store, &self.base_dir, true, None),
        };

        let writer = open_writer_with_options(
            store,
            &self.schema,
            dir,
            self.storage_version,
            WriterOptions {
                add_data_dir,
                base_id,
                external_base_resolver: self.external_base_resolver.clone(),
                allow_external_blob_outside_bases: self.allow_external_blob_outside_bases,
            },
        )
        .await?;

        Ok((writer, fragment))
    }
}
/// Chooses the commit handler to use for `uri`.
///
/// * No handler supplied: one is derived from the URI, unless the store
///   options carry a custom object store, in which case a handler is
///   required and an error is returned.
/// * Handler supplied: it is returned as is, unless the URI starts with
///   `s3+ddb`, which cannot be combined with a custom handler.
///
/// # Errors
///
/// Returns an invalid-input error for the two conflicts above and propagates
/// errors from deriving a handler from the URI.
async fn resolve_commit_handler(
    uri: &str,
    commit_handler: Option<Arc<dyn CommitHandler>>,
    store_options: &Option<ObjectStoreParams>,
) -> Result<Arc<dyn CommitHandler>> {
    let Some(custom_handler) = commit_handler else {
        // The `object_store` field is deprecated but still has to be honoured here.
        #[allow(deprecated)]
        let has_custom_store = store_options
            .as_ref()
            .is_some_and(|opts| opts.object_store.is_some());
        if has_custom_store {
            return Err(Error::invalid_input(
                "when creating a dataset with a custom object store the commit_handler must also be specified",
            ));
        }
        return commit_handler_from_url(uri, store_options).await;
    };

    if uri.starts_with("s3+ddb") {
        return Err(Error::invalid_input(
            "`s3+ddb://` scheme and custom commit handler are mutually exclusive",
        ));
    }
    Ok(custom_handler)
}
/// Turns `source` into an iterator of streams so that a write can be retried.
///
/// * Retries disabled: the iterator yields `source` once.
/// * Retries enabled and the stream's size hint is exactly one item: the
///   stream is collected and the iterator endlessly yields fresh streams
///   over clones of the collected batches.
/// * Retries enabled otherwise: the stream is replayed through a
///   `SpillStreamIter` created with a 100 MiB memory limit.
///
/// # Errors
///
/// Propagates errors from collecting the stream or creating the spill.
async fn new_source_iter(
    source: SendableRecordBatchStream,
    enable_retries: bool,
) -> Result<Box<dyn Iterator<Item = SendableRecordBatchStream> + Send + 'static>> {
    if !enable_retries {
        return Ok(Box::new(std::iter::once(source)));
    }

    let schema = source.schema();
    let is_single_item = source.size_hint() == (1, Some(1));
    if !is_single_item {
        let spilled = SpillStreamIter::try_new(source, 100 * 1024 * 1024).await?;
        return Ok(Box::new(spilled));
    }

    let batches: Vec<RecordBatch> = source.try_collect().await?;
    let replay = std::iter::repeat_with(move || {
        let items = batches.clone().into_iter().map(Ok);
        Box::pin(RecordBatchStreamAdapter::new(
            schema.clone(),
            futures::stream::iter(items),
        )) as SendableRecordBatchStream
    });
    Ok(Box::new(replay))
}
/// Iterator that hands out a new stream read from a replayable spill on every
/// call to `next`.
struct SpillStreamIter {
// Receiving side of the spill; `read()` is called for each iterator item.
receiver: SpillReceiver,
// Handle of the spawned task that feeds the source stream into the spill.
// It is never read after construction.
#[allow(dead_code)] sender_handle: tokio::task::JoinHandle<SpillSender>,
// Temporary directory holding the spill file; never read after construction.
// Presumably kept so the directory outlives the iterator — confirm `TempDir` drop behavior.
#[allow(dead_code)] tmp_dir: TempDir,
}
impl SpillStreamIter {
pub async fn try_new(
mut source: SendableRecordBatchStream,
memory_limit: usize,
) -> Result<Self> {
let tmp_dir = tokio::task::spawn_blocking(|| {
TempDir::try_new()
.map_err(|e| Error::invalid_input(format!("Failed to create temp dir: {}", e)))
})
.await
.ok()
.expect_ok()??;
let tmp_path = tmp_dir.std_path().join("spill.arrows");
let (mut sender, receiver) = create_replay_spill(tmp_path, source.schema(), memory_limit);
let sender_handle = tokio::task::spawn(async move {
while let Some(res) = source.next().await {
match res {
Ok(batch) => match sender.write(batch).await {
Ok(_) => {}
Err(e) => {
sender.send_error(e);
break;
}
},
Err(e) => {
sender.send_error(e);
break;
}
}
}
if let Err(err) = sender.finish().await {
sender.send_error(err);
}
sender
});
Ok(Self {
receiver,
tmp_dir,
sender_handle,
})
}
}
impl Iterator for SpillStreamIter {
    type Item = SendableRecordBatchStream;

    /// Always yields a new stream read from the spill; the iterator never
    /// ends.
    fn next(&mut self) -> Option<Self::Item> {
        let replay = self.receiver.read();
        Some(replay)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, RecordBatchIterator, RecordBatchReader, StructArray};
use arrow_schema::{DataType, Field as ArrowField, Fields, Schema as ArrowSchema};
use datafusion::{error::DataFusionError, physical_plan::stream::RecordBatchStreamAdapter};
use datafusion_physical_plan::RecordBatchStream;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, RowCount, array, gen_batch};
use lance_file::previous::reader::FileReader as PreviousFileReader;
use lance_io::traits::Reader;
#[tokio::test]
async fn test_chunking_large_batches() {
let schema = Arc::new(ArrowSchema::new(vec![arrow::datatypes::Field::new(
"a",
DataType::Int32,
false,
)]));
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from_iter(0..28))])
.unwrap();
let batches: Vec<RecordBatch> =
vec![batch.slice(0, 10), batch.slice(10, 10), batch.slice(20, 8)];
let stream = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(batches.into_iter().map(Ok::<_, DataFusionError>)),
);
let chunks: Vec<Vec<RecordBatch>> = chunk_stream(Box::pin(stream), 3)
.try_collect()
.await
.unwrap();
assert_eq!(chunks.len(), 10);
assert_eq!(chunks[0].len(), 1);
for (i, chunk) in chunks.iter().enumerate() {
let num_rows = chunk.iter().map(|batch| batch.num_rows()).sum::<usize>();
if i < chunks.len() - 1 {
assert_eq!(num_rows, 3);
} else {
assert_eq!(num_rows, 1);
}
}
assert_eq!(chunks[3].len(), 2);
assert_eq!(chunks[3][0].num_rows(), 1);
assert_eq!(chunks[3][1].num_rows(), 2);
}
#[tokio::test]
async fn test_chunking_small_batches() {
let schema = Arc::new(ArrowSchema::new(vec![arrow::datatypes::Field::new(
"a",
DataType::Int32,
false,
)]));
let batch =
RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from_iter(0..30))])
.unwrap();
let batches: Vec<RecordBatch> = (0..10).map(|i| batch.slice(i * 3, 3)).collect();
let stream = RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(batches.into_iter().map(Ok::<_, DataFusionError>)),
);
let chunks: Vec<Vec<RecordBatch>> = chunk_stream(Box::pin(stream), 10)
.try_collect()
.await
.unwrap();
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0].len(), 4);
assert_eq!(chunks[0][0], batch.slice(0, 3));
assert_eq!(chunks[0][1], batch.slice(3, 3));
assert_eq!(chunks[0][2], batch.slice(6, 3));
assert_eq!(chunks[0][3], batch.slice(9, 1));
for chunk in &chunks {
let num_rows = chunk.iter().map(|batch| batch.num_rows()).sum::<usize>();
assert_eq!(num_rows, 10);
}
}
#[tokio::test]
async fn test_file_size() {
let reader_to_frags = |data_reader: Box<dyn RecordBatchReader + Send>| {
let schema = data_reader.schema();
let data_reader =
data_reader.map(|rb| rb.map_err(datafusion::error::DataFusionError::from));
let data_stream = Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(data_reader),
));
let write_params = WriteParams {
max_rows_per_file: 1024 * 1024, max_bytes_per_file: 2 * 1024,
mode: WriteMode::Create,
..Default::default()
};
async move {
let schema = Schema::try_from(schema.as_ref()).unwrap();
let object_store = Arc::new(ObjectStore::memory());
write_fragments_internal(
None,
object_store,
&Path::from("test"),
schema,
data_stream,
write_params,
None,
)
.await
}
};
let data_reader = Box::new(
gen_batch()
.anon_col(array::rand_fsb(1024))
.into_reader_rows(RowCount::from(10 * 1024), BatchCount::from(2)),
);
let (fragments, _) = reader_to_frags(data_reader).await.unwrap();
assert_eq!(fragments.len(), 2);
}
#[tokio::test]
async fn test_max_rows_per_file() {
let reader_to_frags = |data_reader: Box<dyn RecordBatchReader + Send>| {
let schema = data_reader.schema();
let data_reader =
data_reader.map(|rb| rb.map_err(datafusion::error::DataFusionError::from));
let data_stream = Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(data_reader),
));
let write_params = WriteParams {
max_rows_per_file: 5000, max_bytes_per_file: 1024 * 1024 * 1024, mode: WriteMode::Create,
..Default::default()
};
async move {
let schema = Schema::try_from(schema.as_ref()).unwrap();
let object_store = Arc::new(ObjectStore::memory());
write_fragments_internal(
None,
object_store,
&Path::from("test"),
schema,
data_stream,
write_params,
None,
)
.await
}
};
let data_reader = Box::new(
gen_batch()
.anon_col(array::rand_type(&DataType::Int32))
.into_reader_rows(RowCount::from(12000), BatchCount::from(1)),
);
let (fragments, _) = reader_to_frags(data_reader).await.unwrap();
assert_eq!(fragments.len(), 3);
let row_counts: Vec<usize> = fragments
.iter()
.map(|f| f.physical_rows.unwrap_or(0))
.collect();
assert_eq!(row_counts, vec![5000, 5000, 2000]);
}
#[tokio::test]
async fn test_max_rows_per_group() {
let reader_to_frags = |data_reader: Box<dyn RecordBatchReader + Send>,
version: LanceFileVersion| {
let schema = data_reader.schema();
let data_reader =
data_reader.map(|rb| rb.map_err(datafusion::error::DataFusionError::from));
let data_stream = Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::iter(data_reader),
));
let write_params = WriteParams {
max_rows_per_file: 5000, max_rows_per_group: 3000, mode: WriteMode::Create,
data_storage_version: Some(version),
..Default::default()
};
async move {
let schema = Schema::try_from(schema.as_ref()).unwrap();
let object_store = Arc::new(ObjectStore::memory());
write_fragments_internal(
None,
object_store,
&Path::from("test"),
schema,
data_stream,
write_params,
None,
)
.await
}
};
let data_reader_v1 = Box::new(
gen_batch()
.anon_col(array::rand_type(&DataType::Int32))
.into_reader_rows(RowCount::from(9000), BatchCount::from(1)),
);
let (fragments_v1, _) = reader_to_frags(data_reader_v1, LanceFileVersion::Legacy)
.await
.unwrap();
let row_counts_v1: Vec<usize> = fragments_v1
.iter()
.map(|f| f.physical_rows.unwrap_or(0))
.collect();
assert_eq!(fragments_v1.len(), 2);
assert_eq!(row_counts_v1, vec![6000, 3000]);
let data_reader_v2 = Box::new(
gen_batch()
.anon_col(array::rand_type(&DataType::Int32))
.into_reader_rows(RowCount::from(9000), BatchCount::from(1)),
);
let (fragments_v2, _) = reader_to_frags(data_reader_v2, LanceFileVersion::Stable)
.await
.unwrap();
let row_counts_v2: Vec<usize> = fragments_v2
.iter()
.map(|f| f.physical_rows.unwrap_or(0))
.collect();
assert_eq!(fragments_v2.len(), 2);
assert_eq!(row_counts_v2, vec![5000, 4000]);
assert_eq!(fragments_v1.len(), fragments_v2.len());
assert_ne!(row_counts_v1, row_counts_v2);
}
/// Writes the same 1024-row `Int32` batch once per file-format version and asserts that
/// the single resulting data file records the major/minor numbers that
/// `LanceFileVersion::to_numbers` reports for that version.
#[tokio::test]
async fn test_file_write_version() {
    let arrow_schema = Arc::new(ArrowSchema::new(vec![arrow::datatypes::Field::new(
        "a",
        DataType::Int32,
        false,
    )]));
    let batch = RecordBatch::try_new(
        arrow_schema.clone(),
        vec![Arc::new(Int32Array::from_iter(0..1024))],
    )
    .unwrap();

    for version in [
        LanceFileVersion::Legacy,
        LanceFileVersion::V2_0,
        LanceFileVersion::V2_1,
        LanceFileVersion::V2_2,
        LanceFileVersion::Stable,
        LanceFileVersion::Next,
    ] {
        let (expected_major, expected_minor) = version.to_numbers();

        // The stream is handed to the writer by value, so build a fresh one per version.
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            arrow_schema.clone(),
            futures::stream::iter(std::iter::once(Ok(batch.clone()))),
        ));
        let lance_schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
        // NOTE(review): `max_rows_per_group` is set to 1 while the assertions below still
        // expect one fragment with one file of 1024 rows — presumably the setting does not
        // influence fragment layout; confirm against the writer.
        let params = WriteParams {
            data_storage_version: Some(version),
            max_rows_per_group: 1,
            ..Default::default()
        };

        let (fragments, _) = write_fragments_internal(
            None,
            Arc::new(ObjectStore::memory()),
            &Path::from("test"),
            lance_schema,
            stream,
            params,
            None,
        )
        .await
        .unwrap();

        assert_eq!(fragments.len(), 1);
        let fragment = &fragments[0];
        assert_eq!(fragment.files.len(), 1);
        assert_eq!(fragment.physical_rows, Some(1024));

        let data_file = &fragment.files[0];
        assert_eq!(
            data_file.file_major_version, expected_major,
            "version: {}",
            version
        );
        assert_eq!(
            data_file.file_minor_version, expected_minor,
            "version: {}",
            version
        );
    }
}
/// Writes a legacy-format file whose schema field ids are out of order (pre-order ids
/// `[3, 0, 1]`), checks that the data file lists the ids as `[0, 1, 3]`, and verifies
/// that reading the file back with the same schema returns the original batch.
#[tokio::test]
async fn test_file_v1_schema_order() {
    let inner_fields = Fields::from(vec![ArrowField::new("b", DataType::Int32, false)]);
    let arrow_schema = ArrowSchema::new(vec![
        ArrowField::new("d", DataType::Int32, false),
        ArrowField::new("a", DataType::Struct(inner_fields.clone()), false),
    ]);

    let mut schema = Schema::try_from(&arrow_schema).unwrap();
    // The reassignments are applied one at a time and each lookup uses the ids as they
    // stand after the previous step, so the order of this list matters.
    for (current_id, new_id) in [(0, 3), (1, 0), (2, 1)] {
        schema.mut_field_by_id(current_id).unwrap().id = new_id;
    }
    let pre_order_ids: Vec<_> = schema.fields_pre_order().map(|field| field.id).collect();
    assert_eq!(pre_order_ids, vec![3, 0, 1]);

    let struct_column = StructArray::new(
        inner_fields,
        vec![Arc::new(Int32Array::from(vec![3, 4]))],
        None,
    );
    let data = RecordBatch::try_new(
        Arc::new(arrow_schema.clone()),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(struct_column),
        ],
    )
    .unwrap();

    let object_store = Arc::new(ObjectStore::memory());
    let base_path = Path::from("test");
    let stream = Box::pin(RecordBatchStreamAdapter::new(
        Arc::new(arrow_schema),
        futures::stream::iter(std::iter::once(Ok(data.clone()))),
    ));
    let params = WriteParams {
        data_storage_version: Some(LanceFileVersion::Legacy),
        ..Default::default()
    };
    let (fragments, _) = write_fragments_internal(
        None,
        object_store.clone(),
        &base_path,
        schema.clone(),
        stream,
        params,
        None,
    )
    .await
    .unwrap();

    assert_eq!(fragments.len(), 1);
    let fragment = &fragments[0];
    assert_eq!(fragment.files.len(), 1);
    let data_file = &fragment.files[0];
    assert_eq!(data_file.fields, vec![0, 1, 3]);

    // Re-open the written file with the legacy reader and compare the round-tripped data.
    let file_path = base_path.child(DATA_DIR).child(data_file.path.as_str());
    let raw_reader: Arc<dyn Reader> = object_store.open(&file_path).await.unwrap().into();
    // NOTE(review): the positional arguments `0, 0, 3` are passed through unchanged;
    // their meaning is defined by `PreviousFileReader::try_new_from_reader` — confirm there.
    let reader = PreviousFileReader::try_new_from_reader(
        &file_path,
        raw_reader,
        None,
        schema.clone(),
        0,
        0,
        3,
        None,
    )
    .await
    .unwrap();
    assert_eq!(reader.num_batches(), 1);
    let round_tripped = reader.read_batch(0, .., &schema).await.unwrap();
    assert_eq!(round_tripped, data);
}
/// Builds a `WriterGenerator` with one explicit target base (id 2) and checks that the
/// first writer it hands out is paired with a fragment whose id is 0.
#[tokio::test]
async fn test_explicit_data_file_bases_writer_generator() {
    use arrow::datatypes::{DataType, Field as ArrowField, Schema as ArrowSchema};
    use lance_io::object_store::ObjectStore;
    use std::sync::Arc;

    let id_field = ArrowField::new("id", DataType::Int32, false);
    let schema = Schema::try_from(&ArrowSchema::new(vec![id_field])).unwrap();

    let store = Arc::new(ObjectStore::memory());
    let bucket_dir = Path::from("test/bucket2");

    let generator = WriterGenerator::new(
        store.clone(),
        &bucket_dir,
        &schema,
        LanceFileVersion::Stable,
        Some(vec![TargetBaseInfo {
            base_id: 2,
            object_store: store.clone(),
            base_dir: bucket_dir.clone(),
            is_dataset_root: false,
        }]),
        None,
        false,
    );

    let (writer, fragment) = generator.new_writer().await.unwrap();
    assert_eq!(fragment.id, 0);
    // The writer is created only to prove construction succeeds; nothing is written.
    drop(writer);
}
/// Writes a three-row batch through `open_writer_with_options` (with `add_data_dir`
/// disabled), finishes the writer, and checks that the returned data file has a
/// non-empty path and accepts a `base_id` assignment.
#[tokio::test]
async fn test_writer_with_base_id() {
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field as ArrowField, Schema as ArrowSchema};
    use arrow::record_batch::RecordBatch;
    use lance_io::object_store::ObjectStore;
    use std::sync::Arc;

    const BASE_ID: u32 = 2;

    let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
        "id",
        DataType::Int32,
        false,
    )]));
    let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
    let ids = Int32Array::from(vec![1, 2, 3]);
    let batch = RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(ids)]).unwrap();

    let store = Arc::new(ObjectStore::memory());
    let bucket_dir = Path::from("test/bucket2");
    let options = WriterOptions {
        add_data_dir: false,
        ..Default::default()
    };
    let mut writer = open_writer_with_options(
        &store,
        &schema,
        &bucket_dir,
        LanceFileVersion::Stable,
        options,
    )
    .await
    .unwrap();
    writer.write(&[batch]).await.unwrap();

    let (_num_rows, mut data_file) = writer.finish().await.unwrap();
    // NOTE(review): the base id is assigned here by the test itself, so the equality
    // assertion below only re-reads the value that was just stored.
    data_file.base_id = Some(BASE_ID);
    assert_eq!(data_file.base_id, Some(BASE_ID));
    assert!(!data_file.path.is_empty());
}
#[tokio::test]
async fn test_round_robin_target_base_selection() {
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field as ArrowField, Schema as ArrowSchema};
use arrow::record_batch::RecordBatch;
use lance_io::object_store::ObjectStore;
use std::sync::Arc;
let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();
let store1 = Arc::new(ObjectStore::memory());
let store2 = Arc::new(ObjectStore::memory());
let store3 = Arc::new(ObjectStore::memory());
let target_bases = vec![
TargetBaseInfo {
base_id: 1,
object_store: store1.clone(),
base_dir: Path::from("base1"),
is_dataset_root: false,
},
TargetBaseInfo {
base_id: 2,
object_store: store2.clone(),
base_dir: Path::from("base2"),
is_dataset_root: false,
},
TargetBaseInfo {
base_id: 3,
object_store: store3.clone(),
base_dir: Path::from("base3"),
is_dataset_root: false,
},
];
let writer_generator = WriterGenerator::new(
Arc::new(ObjectStore::memory()),
&Path::from("default"),
&schema,
LanceFileVersion::Stable,
Some(target_bases),
None,
false,
);
let batch = RecordBatch::try_new(
arrow_schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let mut base_ids = Vec::new();
for _ in 0..6 {
let (mut writer, _fragment) = writer_generator.new_writer().await.unwrap();
writer.write(std::slice::from_ref(&batch)).await.unwrap();
let (_num_rows, data_file) = writer.finish().await.unwrap();
base_ids.push(data_file.base_id.unwrap());
}
assert_eq!(base_ids, vec![1, 2, 3, 1, 2, 3]);
}
/// Checks that stripping the leading `/` from `url::Url::path()` yields the expected
/// relative path for URIs using the `s3`, `gs`, `az`, `abfss` and `file` schemes.
///
/// The body performs no asynchronous work, so this is a plain synchronous `#[test]`
/// rather than a `#[tokio::test]` that would start a runtime for nothing.
#[test]
fn test_explicit_data_file_bases_path_parsing() {
    // (input URI, expected path with the leading slash removed)
    const CASES: [(&str, &str); 5] = [
        ("s3://multi-path-test/test1/subBucket2", "test1/subBucket2"),
        ("gs://my-bucket/path/to/data", "path/to/data"),
        ("az://container/path/to/data", "path/to/data"),
        (
            "abfss://filesystem@account.dfs.core.windows.net/path/to/data",
            "path/to/data",
        ),
        ("file:///tmp/test/bucket", "tmp/test/bucket"),
    ];

    for (uri, expected_path) in CASES {
        // Name the offending URI if a table entry ever fails to parse.
        let url = url::Url::parse(uri)
            .unwrap_or_else(|err| panic!("test URI {uri} failed to parse: {err}"));
        let parsed_path = url.path().trim_start_matches('/');
        assert_eq!(parsed_path, expected_path, "Failed for URI: {}", uri);
    }
}
/// Exercises the test-local `validate_write_params` helper in CREATE mode:
/// a target base that is listed in `initial_bases` is accepted, an unknown target
/// base id is rejected, and a target base without any `initial_bases` is rejected.
#[tokio::test]
async fn test_write_params_validation() {
    let mut params = WriteParams {
        mode: WriteMode::Create,
        initial_bases: Some(vec![
            BasePath {
                id: 1,
                name: Some("bucket1".to_string()),
                path: "s3://bucket1/path1".to_string(),
                is_dataset_root: true,
            },
            BasePath {
                id: 2,
                name: Some("bucket2".to_string()),
                path: "s3://bucket2/path2".to_string(),
                is_dataset_root: true,
            },
            BasePath {
                id: 3,
                name: Some("azure-az-base".to_string()),
                path: "az://container/path1".to_string(),
                is_dataset_root: true,
            },
            BasePath {
                id: 4,
                name: Some("azure-abfss-base".to_string()),
                path: "abfss://filesystem@account.dfs.core.windows.net/path1".to_string(),
                is_dataset_root: true,
            },
        ]),
        target_bases: Some(vec![1]),
        ..Default::default()
    };

    // Target base 1 is one of the initial bases: valid.
    let result = validate_write_params(&params);
    assert!(result.is_ok());

    // Target base 99 is not among the initial bases: invalid.
    params.target_bases = Some(vec![99]);
    let result = validate_write_params(&params);
    assert!(result.is_err());

    // A target base with no initial bases at all: invalid.
    params.initial_bases = None;
    params.target_bases = Some(vec![1]);
    let result = validate_write_params(&params);
    assert!(result.is_err());
}
/// Test-local consistency check between `target_bases` and `initial_bases`.
///
/// The check only applies when `params.mode` is `WriteMode::Create` and
/// `params.target_bases` is set; in every other case it returns `Ok(())`.
///
/// # Errors
///
/// Returns an invalid-input error when, in CREATE mode with `target_bases` set:
/// - `target_bases` does not contain exactly one element,
/// - `initial_bases` is `None`, or
/// - the single target base id does not match the `id` of any initial base.
fn validate_write_params(params: &WriteParams) -> Result<()> {
    if matches!(params.mode, WriteMode::Create)
        && let Some(target_bases) = &params.target_bases
    {
        if target_bases.len() != 1 {
            return Err(Error::invalid_input(format!(
                "target_bases with {} elements is not supported",
                target_bases.len()
            )));
        }
        let target_base_id = target_bases[0];
        if let Some(initial_bases) = &params.initial_bases {
            if !initial_bases.iter().any(|bp| bp.id == target_base_id) {
                return Err(Error::invalid_input(format!(
                    "target_base_id {} must be one of the initial_bases in CREATE mode",
                    target_base_id
                )));
            }
        } else {
            return Err(Error::invalid_input(
                "initial_bases must be provided when target_bases is specified in CREATE mode",
            ));
        }
    }
    Ok(())
}
/// Creates a dataset with two registered bases and base 1 as the write target, then
/// checks the row count, the registered base names, and that every fragment has a file
/// in base 1. Finally verifies that supplying both `target_bases` and
/// `target_base_names_or_paths` is rejected with the expected error message.
#[tokio::test]
async fn test_multi_base_create() {
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};

    let test_uri = "memory://multi_base_test";
    let primary_uri = format!("{}/primary", test_uri);
    let base1_uri = format!("{}/base1", test_uri);
    let base2_uri = format!("{}/base2", test_uri);

    // Every base in this test is registered with `is_dataset_root: true`.
    let root_base = |id, name: &str, uri: &str| BasePath {
        id,
        name: Some(name.to_string()),
        path: uri.to_string(),
        is_dataset_root: true,
    };

    let mut data_gen =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let create_params = WriteParams {
        mode: WriteMode::Create,
        initial_bases: Some(vec![
            root_base(1, "base1", &base1_uri),
            root_base(2, "base2", &base2_uri),
        ]),
        target_bases: Some(vec![1]),
        ..Default::default()
    };
    let dataset = Dataset::write(data_gen.batch(5), &primary_uri, Some(create_params))
        .await
        .unwrap();

    assert_eq!(dataset.count_rows(None).await.unwrap(), 5);
    assert_eq!(dataset.manifest.base_paths.len(), 2);
    for expected_name in ["base1", "base2"] {
        assert!(
            dataset
                .manifest
                .base_paths
                .values()
                .any(|bp| bp.name == Some(expected_name.to_string()))
        );
    }

    let fragments = dataset.get_fragments();
    assert!(!fragments.is_empty());
    assert!(fragments.iter().all(|fragment| {
        fragment
            .metadata
            .files
            .iter()
            .any(|file| file.base_id == Some(1))
    }));

    // Specifying the target both by id and by name/path must fail.
    let mut data_gen2 =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let conflicting_params = WriteParams {
        mode: WriteMode::Create,
        initial_bases: Some(vec![root_base(1, "base1", &base1_uri)]),
        target_bases: Some(vec![1]),
        target_base_names_or_paths: Some(vec!["base1".to_string()]),
        ..Default::default()
    };
    let result = Dataset::write(
        data_gen2.batch(5),
        &format!("{}/test_validation", test_uri),
        Some(conflicting_params),
    )
    .await;
    assert!(result.is_err());
    let message = result.unwrap_err().to_string();
    assert!(message.contains("Cannot specify both target_base_names_or_paths and target_bases"));
}
/// Creates a two-base dataset targeting base 1, overwrites it targeting base 2, and
/// checks that both bases stay registered while all remaining files live in base 2.
/// Finally verifies that passing `initial_bases` in Overwrite mode is rejected with
/// "Cannot register new bases in Overwrite mode".
#[tokio::test]
async fn test_multi_base_overwrite() {
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};

    let test_uri = "memory://overwrite_test";
    let primary_uri = format!("{}/primary", test_uri);
    let base1_uri = format!("{}/base1", test_uri);
    let base2_uri = format!("{}/base2", test_uri);
    // This binding is read further down, so it carries no leading underscore.
    let base3_uri = format!("{}/base3", test_uri);

    // Step 1: create the dataset with two bases, writing into base 1.
    let mut data_gen =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let dataset = Dataset::write(
        data_gen.batch(3),
        &primary_uri,
        Some(WriteParams {
            mode: WriteMode::Create,
            initial_bases: Some(vec![
                BasePath {
                    id: 1,
                    name: Some("base1".to_string()),
                    path: base1_uri.clone(),
                    is_dataset_root: true,
                },
                BasePath {
                    id: 2,
                    name: Some("base2".to_string()),
                    path: base2_uri.clone(),
                    is_dataset_root: true,
                },
            ]),
            target_bases: Some(vec![1]),
            ..Default::default()
        }),
    )
    .await
    .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 3);

    // Step 2: overwrite, this time writing into base 2.
    let mut data_gen2 =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let dataset = Dataset::write(
        data_gen2.batch(2),
        Arc::new(dataset),
        Some(WriteParams {
            mode: WriteMode::Overwrite,
            target_bases: Some(vec![2]),
            ..Default::default()
        }),
    )
    .await
    .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 2);
    assert_eq!(dataset.manifest.base_paths.len(), 2);
    assert!(
        dataset
            .manifest
            .base_paths
            .values()
            .any(|bp| bp.name == Some("base1".to_string()))
    );
    assert!(
        dataset
            .manifest
            .base_paths
            .values()
            .any(|bp| bp.name == Some("base2".to_string()))
    );
    let fragments = dataset.get_fragments();
    assert!(
        fragments
            .iter()
            .all(|f| f.metadata.files.iter().all(|file| file.base_id == Some(2)))
    );

    // Step 3: registering a new base while overwriting must be rejected.
    let mut data_gen3 =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let result = Dataset::write(
        data_gen3.batch(2),
        Arc::new(dataset),
        Some(WriteParams {
            mode: WriteMode::Overwrite,
            initial_bases: Some(vec![BasePath {
                id: 3,
                name: Some("base3".to_string()),
                path: base3_uri.clone(),
                is_dataset_root: true,
            }]),
            target_bases: Some(vec![1]),
            ..Default::default()
        }),
    )
    .await;
    assert!(result.is_err());
    assert!(
        result
            .unwrap_err()
            .to_string()
            .contains("Cannot register new bases in Overwrite mode")
    );
}
/// Creates a two-base dataset (3 rows into base 1), appends 2 rows into base 1 and
/// 4 rows into base 2, and checks the running row counts and that both bases end up
/// holding data. Finally verifies that passing `initial_bases` in Append mode is
/// rejected with "Cannot register new bases in Append mode".
#[tokio::test]
async fn test_multi_base_append() {
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};

    let test_uri = "memory://append_test";
    let primary_uri = format!("{}/primary", test_uri);
    let base1_uri = format!("{}/base1", test_uri);
    let base2_uri = format!("{}/base2", test_uri);

    // Each write uses its own freshly constructed generator.
    let make_generator =
        || BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let append_into = |base_id| WriteParams {
        mode: WriteMode::Append,
        target_bases: Some(vec![base_id]),
        ..Default::default()
    };

    let mut data_gen = make_generator();
    let create_params = WriteParams {
        mode: WriteMode::Create,
        initial_bases: Some(vec![
            BasePath {
                id: 1,
                name: Some("base1".to_string()),
                path: base1_uri.clone(),
                is_dataset_root: true,
            },
            BasePath {
                id: 2,
                name: Some("base2".to_string()),
                path: base2_uri.clone(),
                is_dataset_root: true,
            },
        ]),
        target_bases: Some(vec![1]),
        ..Default::default()
    };
    let dataset = Dataset::write(data_gen.batch(3), &primary_uri, Some(create_params))
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 3);

    let mut data_gen2 = make_generator();
    let dataset = Dataset::write(data_gen2.batch(2), Arc::new(dataset), Some(append_into(1)))
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 5);
    assert_eq!(dataset.manifest.base_paths.len(), 2);

    let mut data_gen3 = make_generator();
    let dataset = Dataset::write(data_gen3.batch(4), Arc::new(dataset), Some(append_into(2)))
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 9);

    let fragments = dataset.get_fragments();
    let has_data_in = |base_id| {
        fragments.iter().any(|fragment| {
            fragment
                .metadata
                .files
                .iter()
                .any(|file| file.base_id == Some(base_id))
        })
    };
    assert!(has_data_in(1), "Should have data in base1");
    assert!(has_data_in(2), "Should have data in base2");

    // Registering a new base while appending must be rejected.
    let mut data_gen4 = make_generator();
    let rejected_params = WriteParams {
        mode: WriteMode::Append,
        initial_bases: Some(vec![BasePath {
            id: 3,
            name: Some("base3".to_string()),
            path: format!("{}/base3", test_uri),
            is_dataset_root: true,
        }]),
        target_bases: Some(vec![1]),
        ..Default::default()
    };
    let result = Dataset::write(data_gen4.batch(2), Arc::new(dataset), Some(rejected_params)).await;
    assert!(result.is_err());
    let message = result.unwrap_err().to_string();
    assert!(message.contains("Cannot register new bases in Append mode"));
}
/// Creates a dataset on the local filesystem with two target bases, one registered with
/// `is_dataset_root: true` and one with `false`, and checks that:
/// - the flag is preserved in the manifest,
/// - both bases receive data,
/// - `.lance` files of the root base appear under `<base1>/data`, while those of the
///   non-root base appear directly in `<base2>` and not in `<base2>/data`.
#[tokio::test]
async fn test_multi_base_is_dataset_root_flag() {
    use lance_core::utils::tempfile::TempDir;
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};

    /// Collects the entries of `dir` whose file extension is `lance`.
    fn lance_files_in(dir: &std::path::Path) -> Vec<std::fs::DirEntry> {
        std::fs::read_dir(dir)
            .unwrap()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "lance"))
            .collect()
    }

    let test_dir = TempDir::default();
    let primary_uri = test_dir.path_str();
    let base1_dir = test_dir.std_path().join("base1");
    let base2_dir = test_dir.std_path().join("base2");
    for dir in [&base1_dir, &base2_dir] {
        std::fs::create_dir_all(dir).unwrap();
    }
    let base1_uri = format!("file://{}", base1_dir.display());
    let base2_uri = format!("file://{}", base2_dir.display());

    let mut data_gen =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    // 10 rows with at most 5 rows per file, written across both target bases.
    let params = WriteParams {
        mode: WriteMode::Create,
        max_rows_per_file: 5,
        initial_bases: Some(vec![
            BasePath {
                id: 1,
                name: Some("base1".to_string()),
                path: base1_uri.clone(),
                is_dataset_root: true,
            },
            BasePath {
                id: 2,
                name: Some("base2".to_string()),
                path: base2_uri.clone(),
                is_dataset_root: false,
            },
        ]),
        target_bases: Some(vec![1, 2]),
        ..Default::default()
    };
    let dataset = Dataset::write(data_gen.batch(10), &primary_uri, Some(params))
        .await
        .unwrap();

    assert_eq!(dataset.count_rows(None).await.unwrap(), 10);
    assert_eq!(dataset.manifest.base_paths.len(), 2);

    let registered = &dataset.manifest.base_paths;
    let base1 = registered
        .values()
        .find(|bp| bp.name == Some("base1".to_string()))
        .expect("base1 not found");
    let base2 = registered
        .values()
        .find(|bp| bp.name == Some("base2".to_string()))
        .expect("base2 not found");
    assert!(
        base1.is_dataset_root,
        "base1 should have is_dataset_root=true"
    );
    assert!(
        !base2.is_dataset_root,
        "base2 should have is_dataset_root=false"
    );

    let fragments = dataset.get_fragments();
    assert!(!fragments.is_empty());
    let has_data_in = |base_id| {
        fragments.iter().any(|fragment| {
            fragment
                .metadata
                .files
                .iter()
                .any(|file| file.base_id == Some(base_id))
        })
    };
    assert!(has_data_in(1), "Should have data in base1");
    assert!(has_data_in(2), "Should have data in base2");

    // Root base: data files are expected below a `data` sub-directory.
    let base1_data_dir = base1_dir.join("data");
    assert!(base1_data_dir.exists(), "base1/data directory should exist");
    assert!(
        !lance_files_in(&base1_data_dir).is_empty(),
        "base1/data should contain .lance files"
    );

    // Non-root base: data files are expected directly in the base directory.
    assert!(
        !lance_files_in(&base2_dir).is_empty(),
        "base2 should contain .lance files directly"
    );
    let base2_data_dir = base2_dir.join("data");
    if base2_data_dir.exists() {
        assert!(
            lance_files_in(&base2_data_dir).is_empty(),
            "base2/data should NOT contain .lance files"
        );
    }
}
/// Selects write targets through `target_base_names_or_paths`: the initial create
/// passes the name `"base1"`, the following append passes the URI of base 2. Checks
/// that the first write lands entirely in base 1 and that the append produces exactly
/// one fragment whose files are all in base 2.
#[tokio::test]
async fn test_multi_base_target_by_path_uri() {
    use lance_core::utils::tempfile::TempDir;
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32};

    let test_dir = TempDir::default();
    let primary_uri = test_dir.path_str();
    let base1_dir = test_dir.std_path().join("base1");
    let base2_dir = test_dir.std_path().join("base2");
    for dir in [&base1_dir, &base2_dir] {
        std::fs::create_dir_all(dir).unwrap();
    }
    let base1_uri = format!("file://{}", base1_dir.display());
    let base2_uri = format!("file://{}", base2_dir.display());

    // Create: target chosen by base name.
    let mut data_gen =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let create_params = WriteParams {
        mode: WriteMode::Create,
        max_rows_per_file: 5,
        initial_bases: Some(vec![
            BasePath {
                id: 1,
                name: Some("base1".to_string()),
                path: base1_uri.clone(),
                is_dataset_root: true,
            },
            BasePath {
                id: 2,
                name: Some("base2".to_string()),
                path: base2_uri.clone(),
                is_dataset_root: true,
            },
        ]),
        target_base_names_or_paths: Some(vec!["base1".to_string()]),
        ..Default::default()
    };
    let dataset = Dataset::write(data_gen.batch(10), &primary_uri, Some(create_params))
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 10);
    let fragments = dataset.get_fragments();
    assert!(
        fragments
            .iter()
            .all(|f| f.metadata.files.iter().all(|file| file.base_id == Some(1)))
    );

    // Append: target chosen by base path URI.
    let mut data_gen2 =
        BatchGenerator::new().col(Box::new(IncrementingInt32::new().named("id".to_owned())));
    let append_params = WriteParams {
        mode: WriteMode::Append,
        target_base_names_or_paths: Some(vec![base2_uri.clone()]),
        max_rows_per_file: 5,
        ..Default::default()
    };
    let dataset = Dataset::write(data_gen2.batch(5), Arc::new(dataset), Some(append_params))
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 15);

    let fragments = dataset.get_fragments();
    let has_data_in = |base_id| {
        fragments.iter().any(|fragment| {
            fragment
                .metadata
                .files
                .iter()
                .any(|file| file.base_id == Some(base_id))
        })
    };
    assert!(has_data_in(1), "Should have data in base1");
    assert!(has_data_in(2), "Should have data in base2");

    let base2_fragment_count = fragments
        .iter()
        .filter(|f| f.metadata.files.iter().all(|file| file.base_id == Some(2)))
        .count();
    assert_eq!(base2_fragment_count, 1, "Should have 1 fragment in base2");
}
/// Feeds `write_fragments_internal` a stream that yields no batches and asserts that
/// the call succeeds and returns no fragments.
#[tokio::test]
async fn test_empty_stream_write() {
    use lance_io::object_store::ObjectStore;

    let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
        "id",
        DataType::Int32,
        false,
    )]));
    let schema = Schema::try_from(arrow_schema.as_ref()).unwrap();

    // A stream with the right schema but zero items.
    let no_batches = std::iter::empty::<std::result::Result<RecordBatch, DataFusionError>>();
    let data_stream = Box::pin(RecordBatchStreamAdapter::new(
        arrow_schema.clone(),
        futures::stream::iter(no_batches),
    ));

    let result = write_fragments_internal(
        None,
        Arc::new(ObjectStore::memory()),
        &Path::from("test_empty"),
        schema,
        data_stream,
        WriteParams {
            mode: WriteMode::Create,
            ..Default::default()
        },
        None,
    )
    .await;

    let (fragments, _) = result
        .unwrap_or_else(|e| panic!("Expected write empty stream success, got error: {}", e));
    assert!(
        fragments.is_empty(),
        "Empty stream should create no fragments"
    );
}
/// Creates an in-memory dataset with columns `id: Int32, value: Int32`, then attempts to
/// append a batch whose `value` column is `Float64`. Asserts that the append fails with
/// a schema-related message and that the original dataset handle still reports 5 rows
/// and 2 fields.
#[tokio::test]
async fn test_schema_mismatch_on_append() {
    use arrow_array::record_batch;

    let initial_batch = record_batch!(
        ("id", Int32, [1, 2, 3, 4, 5]),
        ("value", Int32, [10, 20, 30, 40, 50])
    )
    .unwrap();
    let create_params = WriteParams {
        mode: WriteMode::Create,
        ..Default::default()
    };
    let dataset = InsertBuilder::new("memory://")
        .with_params(&create_params)
        .execute(vec![initial_batch])
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 5);
    assert_eq!(dataset.schema().fields.len(), 2);

    // Same column names, but `value` has a different type than the stored schema.
    let mismatched_batch = record_batch!(
        ("id", Int32, [6, 7, 8]),
        ("value", Float64, [60.0, 70.0, 80.0])
    )
    .unwrap();
    let append_params = WriteParams {
        mode: WriteMode::Append,
        ..Default::default()
    };
    let result = InsertBuilder::new(Arc::new(dataset.clone()))
        .with_params(&append_params)
        .execute(vec![mismatched_batch])
        .await;
    assert!(result.is_err(), "Append with mismatched schema should fail");

    let error_msg = result.unwrap_err().to_string().to_lowercase();
    let mentions_schema_problem = ["schema", "type", "mismatch", "field", "not found"]
        .iter()
        .any(|keyword| error_msg.contains(*keyword));
    assert!(
        mentions_schema_problem,
        "Error should mention schema or type mismatch: {}",
        error_msg
    );

    // The failed append must leave the previously created handle unchanged.
    assert_eq!(dataset.count_rows(None).await.unwrap(), 5);
    assert_eq!(dataset.schema().fields.len(), 2);
}
#[tokio::test]
async fn test_disk_full_error() {
use std::io::{self, ErrorKind};
use std::sync::Arc;
use async_trait::async_trait;
use object_store::{
GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, PutMultipartOptions,
PutOptions, PutPayload, PutResult,
};
#[derive(Debug)]
struct DiskFullObjectStore;
impl std::fmt::Display for DiskFullObjectStore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "DiskFullObjectStore")
}
}
#[async_trait]
impl object_store::ObjectStore for DiskFullObjectStore {
async fn put(
&self,
_location: &object_store::path::Path,
_bytes: PutPayload,
) -> object_store::Result<PutResult> {
Err(object_store::Error::Generic {
store: "DiskFullStore",
source: Box::new(io::Error::new(
ErrorKind::StorageFull,
"No space left on device",
)),
})
}
async fn put_opts(
&self,
_location: &object_store::path::Path,
_bytes: PutPayload,
_opts: PutOptions,
) -> object_store::Result<PutResult> {
Err(object_store::Error::Generic {
store: "DiskFullStore",
source: Box::new(io::Error::new(
ErrorKind::StorageFull,
"No space left on device",
)),
})
}
async fn put_multipart(
&self,
_location: &object_store::path::Path,
) -> object_store::Result<Box<dyn MultipartUpload>> {
Err(object_store::Error::NotSupported {
source: "Multipart upload not supported".into(),
})
}
async fn put_multipart_opts(
&self,
_location: &object_store::path::Path,
_opts: PutMultipartOptions,
) -> object_store::Result<Box<dyn MultipartUpload>> {
Err(object_store::Error::NotSupported {
source: "Multipart upload not supported".into(),
})
}
async fn get(
&self,
_location: &object_store::path::Path,
) -> object_store::Result<GetResult> {
Err(object_store::Error::NotFound {
path: "".into(),
source: "".into(),
})
}
async fn get_opts(
&self,
_location: &object_store::path::Path,
_options: GetOptions,
) -> object_store::Result<GetResult> {
Err(object_store::Error::NotFound {
path: "".into(),
source: "".into(),
})
}
async fn delete(
&self,
_location: &object_store::path::Path,
) -> object_store::Result<()> {
Ok(())
}
fn list(
&self,
_prefix: Option<&object_store::path::Path>,
) -> futures::stream::BoxStream<'static, object_store::Result<ObjectMeta>> {
Box::pin(futures::stream::empty())
}
async fn list_with_delimiter(
&self,
_prefix: Option<&object_store::path::Path>,
) -> object_store::Result<ListResult> {
Ok(ListResult {
common_prefixes: vec![],
objects: vec![],
})
}
async fn copy(
&self,
_from: &object_store::path::Path,
_to: &object_store::path::Path,
) -> object_store::Result<()> {
Ok(())
}
async fn copy_if_not_exists(
&self,
_from: &object_store::path::Path,
_to: &object_store::path::Path,
) -> object_store::Result<()> {
Ok(())
}
}
let object_store = Arc::new(lance_io::object_store::ObjectStore::new(
Arc::new(DiskFullObjectStore) as Arc<dyn object_store::ObjectStore>,
url::Url::parse("mock:///test").unwrap(),
None,
None,
false,
true,
lance_io::object_store::DEFAULT_LOCAL_IO_PARALLELISM,
lance_io::object_store::DEFAULT_DOWNLOAD_RETRY_COUNT,
None,
));
let arrow_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new(
"id",
DataType::Int32,
false,
)]));
let batch = RecordBatch::try_new(
arrow_schema.clone(),
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
)
.unwrap();
let data_reader = Box::new(RecordBatchIterator::new(
vec![Ok(batch)].into_iter(),
arrow_schema.clone(),
));
let data_stream = Box::pin(RecordBatchStreamAdapter::new(
arrow_schema,
futures::stream::iter(data_reader.map(|rb| rb.map_err(DataFusionError::from))),
));
let schema = Schema::try_from(data_stream.schema().as_ref()).unwrap();
let write_params = WriteParams {
mode: WriteMode::Create,
..Default::default()
};
let result = write_fragments_internal(
None,
object_store,
&Path::from("test_disk_full"),
schema,
data_stream,
write_params,
None,
)
.await;
assert!(result.is_err(), "Write should fail when disk is full");
let error = result.unwrap_err();
let error_msg = error.to_string().to_lowercase();
assert!(
error_msg.contains("io")
|| error_msg.contains("space")
|| error_msg.contains("storage")
|| error_msg.contains("full"),
"Error should mention IO, space, or storage: {}",
error_msg
);
assert!(
matches!(error, lance_core::Error::IO { .. }),
"Expected IO error, got: {}",
error
);
}
/// Simulates a write that is prepared but not yet committed: after an initial 3-row
/// dataset is created, 3 more rows are written with `execute_uncommitted`. The test
/// checks that a freshly opened dataset still shows 3 rows, then commits the returned
/// transaction and checks that a freshly opened dataset shows all 6 rows with ids
/// `1..=6` in order.
#[tokio::test]
async fn test_write_interruption_recovery() {
    use super::commit::CommitBuilder;
    use arrow_array::record_batch;

    let temp_dir = TempDir::default();
    let dataset_uri = format!("file://{}", temp_dir.std_path().display());

    let initial_batch =
        record_batch!(("id", Int32, [1, 2, 3]), ("value", Utf8, ["a", "b", "c"])).unwrap();
    let dataset = InsertBuilder::new(&dataset_uri)
        .execute(vec![initial_batch])
        .await
        .unwrap();
    assert_eq!(dataset.count_rows(None).await.unwrap(), 3);

    // Write the second batch's data files without committing a new dataset version.
    let pending_batch =
        record_batch!(("id", Int32, [4, 5, 6]), ("value", Utf8, ["d", "e", "f"])).unwrap();
    let append_target = Arc::new(Dataset::open(&dataset_uri).await.unwrap());
    let append_params = WriteParams {
        mode: WriteMode::Append,
        ..Default::default()
    };
    let uncommitted_result = InsertBuilder::new(WriteDestination::Dataset(append_target))
        .with_params(&append_params)
        .execute_uncommitted(vec![pending_batch])
        .await;
    assert!(
        uncommitted_result.is_ok(),
        "Uncommitted write should succeed"
    );
    let transaction = uncommitted_result.unwrap();

    let rows_before_commit = Dataset::open(&dataset_uri)
        .await
        .unwrap()
        .count_rows(None)
        .await
        .unwrap();
    assert_eq!(
        rows_before_commit, 3,
        "Dataset should still have only original 3 rows (uncommitted data not visible)"
    );

    CommitBuilder::new(&dataset_uri)
        .execute(transaction)
        .await
        .unwrap();

    let committed = Dataset::open(&dataset_uri).await.unwrap();
    let rows_after_commit = committed.count_rows(None).await.unwrap();
    assert_eq!(
        rows_after_commit, 6,
        "Dataset should have all 6 rows after commit"
    );

    let mut scanner = committed.scan();
    scanner.project(&["id", "value"]).unwrap();
    let batches = scanner
        .try_into_stream()
        .await
        .unwrap()
        .try_collect::<Vec<_>>()
        .await
        .unwrap();

    // Gather the non-null values of the first projected column ("id") across all batches.
    let mut all_ids: Vec<i32> = Vec::new();
    for scanned in &batches {
        let id_column = scanned
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        all_ids.extend(id_column.iter().flatten());
    }
    assert_eq!(
        all_ids,
        vec![1, 2, 3, 4, 5, 6],
        "All data should be correctly written"
    );
}
}