use arrow_array::{RecordBatch, UInt32Array};
use futures::StreamExt;
use lance_core::utils::deletion::DeletionVector;
use lance_core::{datatypes::Schema, Error, Result};
use lance_table::format::Fragment;
use lance_table::utils::stream::ReadBatchFutStream;
use snafu::{location, Location};
use super::fragment::FragmentReader;
use super::scanner::get_default_batch_size;
use super::write::{open_writer, GenericWriter};
use super::Dataset;
use crate::dataset::FileFragment;
/// Drives a batch-by-batch rewrite of a single fragment.
///
/// Callers pull an input batch with [`Updater::next`], hand a batch with the
/// same number of rows to [`Updater::update`], and call [`Updater::finish`]
/// once the input is drained to obtain the updated fragment metadata.
pub struct Updater {
    /// The fragment being updated; `finish` appends the newly written data
    /// file to its metadata.
    fragment: FileFragment,

    /// Stream of pending batch reads produced by the fragment reader.
    input_stream: ReadBatchFutStream,

    /// The batch most recently returned by `next`; `update` compares the row
    /// count of its argument against this batch.
    last_input: Option<RecordBatch>,

    /// Output writer, created lazily on the first call to `update`.
    writer: Option<Box<dyn GenericWriter>>,

    /// Schema exposed through `schema()`: either supplied at construction or
    /// inferred on the first `update` by merging the fragment schema with the
    /// schema of the updated batch.
    final_schema: Option<Schema>,

    /// Schema handed to the writer: either supplied at construction or
    /// projected out of `final_schema` on the first `update`.
    write_schema: Option<Schema>,

    /// Set once the input stream has been fully consumed.
    finished: bool,

    /// Re-inserts placeholder rows for deleted rows before a batch is written.
    deletion_restorer: DeletionRestorer,
}
impl Updater {
    /// Creates an updater that reads `fragment` through `reader`.
    ///
    /// * `deletion_vector` - deleted rows of the fragment; they are padded
    ///   back into every batch passed to [`Updater::update`].
    /// * `schemas` - optional `(write_schema, final_schema)` pair. When absent
    ///   both schemas are inferred on the first call to [`Updater::update`].
    /// * `batch_size` - requested read batch size. It is ignored when the
    ///   reader reports a legacy batch size, and falls back to the default
    ///   batch size (or 1024) when neither is available.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `reader.read_all`.
    pub(super) fn try_new(
        fragment: FileFragment,
        reader: FragmentReader,
        deletion_vector: DeletionVector,
        schemas: Option<(Schema, Schema)>,
        batch_size: Option<u32>,
    ) -> Result<Self> {
        let (write_schema, final_schema) = match schemas {
            Some((write_schema, final_schema)) => (Some(write_schema), Some(final_schema)),
            None => (None, None),
        };

        // NOTE(review): presumably only legacy (v1) files report a fixed
        // number of rows per batch here -- confirm against `FragmentReader`.
        let legacy_batch_size = reader.legacy_num_rows_in_batch(0);

        // Precedence: legacy batch size, then the caller's choice, then the
        // default. The default is only looked up when it is actually needed.
        let read_batch_size = legacy_batch_size
            .or(batch_size)
            .unwrap_or_else(|| get_default_batch_size().unwrap_or(1024) as u32);

        let input_stream = reader.read_all(read_batch_size)?;
        let deletion_restorer = DeletionRestorer::new(deletion_vector, legacy_batch_size);

        Ok(Self {
            fragment,
            input_stream,
            last_input: None,
            writer: None,
            write_schema,
            final_schema,
            finished: false,
            deletion_restorer,
        })
    }

    /// Returns the fragment being updated.
    pub fn fragment(&self) -> &FileFragment {
        &self.fragment
    }

    /// Returns the dataset that owns the fragment being updated.
    pub fn dataset(&self) -> &Dataset {
        self.fragment.dataset()
    }

    /// Reads the next input batch.
    ///
    /// Returns `Ok(None)` once the input stream is drained (and on every
    /// later call).
    ///
    /// # Errors
    ///
    /// * `Error::NotSupported` when the stream ends while the deletion
    ///   restorer has not yet consumed its whole deletion vector.
    /// * Any error produced while reading the batch.
    pub async fn next(&mut self) -> Result<Option<&RecordBatch>> {
        if self.finished {
            return Ok(None);
        }

        let Some(pending_read) = self.input_stream.next().await else {
            if !self.deletion_restorer.is_exhausted() {
                return Err(Error::NotSupported {
                    source: "Missing too many rows in merge, run compaction to materialize deletions first".into(),
                    location: location!(),
                });
            }
            self.finished = true;
            return Ok(None);
        };

        self.last_input = Some(pending_read.await?);
        Ok(self.last_input.as_ref())
    }

    /// Opens a writer for `schema` in the dataset's object store, using the
    /// file version recorded in the dataset manifest.
    async fn new_writer(&mut self, schema: Schema) -> Result<Box<dyn GenericWriter>> {
        let dataset = self.fragment.dataset();
        let data_storage_version = dataset
            .manifest()
            .data_storage_format
            .lance_file_version()?;
        open_writer(
            &dataset.object_store,
            &schema,
            &dataset.base,
            data_storage_version,
        )
        .await
    }

    /// Fills in `final_schema` and `write_schema` from the schema of `batch`.
    ///
    /// `final_schema` is stored before it is validated, so a validation
    /// failure leaves it set while `write_schema` stays `None`.
    fn infer_schemas(&mut self, batch: &RecordBatch) -> Result<()> {
        let output_schema = batch.schema();

        let mut merged = self.fragment.schema().merge(output_schema.as_ref())?;
        merged.set_field_id(Some(self.fragment.dataset().manifest.max_field_id()));

        let final_schema = self.final_schema.insert(merged);
        final_schema.validate()?;
        self.write_schema = Some(final_schema.project_by_schema(output_schema.as_ref())?);
        Ok(())
    }

    /// Writes the updated version of the batch last returned by
    /// [`Updater::next`].
    ///
    /// Deleted rows are padded back into `batch` before it is written. The
    /// writer (and, if they were not supplied, the schemas) are set up on the
    /// first call.
    ///
    /// # Errors
    ///
    /// * An I/O error when `next` has not produced a batch yet, or when
    ///   `batch` has a different number of rows than that batch.
    /// * Any error from restoring deletions, inferring the schemas, opening
    ///   the writer, or writing.
    pub async fn update(&mut self, batch: RecordBatch) -> Result<()> {
        let Some(source_rows) = self.last_input.as_ref().map(|last| last.num_rows()) else {
            return Err(Error::io(
                "Fragment Updater: no input data is available before update".to_string(),
                location!(),
            ));
        };

        if source_rows != batch.num_rows() {
            return Err(Error::io(
                format!(
                    "Fragment Updater: new batch has different size with the source batch: {} != {}",
                    source_rows,
                    batch.num_rows()
                ),
                location!(),
            ));
        }

        // Pad the deleted rows back in so the written rows line up with the
        // fragment's existing data.
        let batch = self.deletion_restorer.restore(batch)?;

        if self.writer.is_none() {
            if self.write_schema.is_none() {
                self.infer_schemas(&batch)?;
            }
            let write_schema = self.write_schema.as_ref().unwrap().clone();
            let writer = self.new_writer(write_schema).await?;
            self.writer = Some(writer);
        }

        self.writer.as_mut().unwrap().write(&[batch]).await?;
        Ok(())
    }

    /// Finishes the writer (if one was ever created), records its data file
    /// in the fragment metadata, and returns a copy of that metadata.
    pub async fn finish(&mut self) -> Result<Fragment> {
        let Some(writer) = self.writer.as_mut() else {
            // Nothing was written, so the metadata is returned unchanged.
            return Ok(self.fragment.metadata().clone());
        };

        let (_, data_file) = writer.finish().await?;
        self.fragment.metadata.files.push(data_file);
        Ok(self.fragment.metadata().clone())
    }

    /// Returns the final schema, if it was supplied or has been inferred.
    pub fn schema(&self) -> Option<&Schema> {
        self.final_schema.as_ref()
    }
}
/// Tracks progress through a sorted stream of deleted row ids so that
/// placeholder rows can be re-inserted into batches read from a fragment.
struct DeletionRestorer {
    /// Row id of the first row of the next batch to restore; advanced by the
    /// size of every restored batch.
    current_row_id: u32,

    /// Fixed batch size, when one applies. Used to decide when a batch is
    /// full and to validate the size of restored batches.
    legacy_batch_size: Option<u32>,

    /// Remaining deleted row ids in ascending order; `None` once every id
    /// has been consumed.
    deletion_vector_iter: Option<Box<dyn Iterator<Item = u32> + Send>>,

    /// A deleted row id already pulled from the iterator that belongs to a
    /// later batch and therefore still has to be handled.
    last_deleted_row_id: Option<u32>,
}
impl DeletionRestorer {
    /// Creates a restorer positioned at row 0 that walks `deletion_vector`
    /// in sorted order.
    fn new(deletion_vector: DeletionVector, legacy_batch_size: Option<u32>) -> Self {
        let sorted_deletions = deletion_vector.into_sorted_iter();
        Self {
            current_row_id: 0,
            legacy_batch_size,
            deletion_vector_iter: Some(sorted_deletions),
            last_deleted_row_id: None,
        }
    }

    /// Returns `true` once every deleted row id has been consumed.
    ///
    /// The iterator is only dropped after a lookup observes its end, so this
    /// stays `false` until `deleted_batch_offsets_in_range` has run at least
    /// once, even for an empty deletion vector.
    fn is_exhausted(&self) -> bool {
        matches!(self.deletion_vector_iter, None)
    }

    /// Returns `true` when a fixed batch size applies and `num_rows` has
    /// reached it. Without a fixed batch size a batch is never full.
    fn is_full(batch_size: Option<u32>, num_rows: u32) -> bool {
        match batch_size {
            Some(legacy_batch_size) => {
                // A batch must never grow beyond the fixed batch size.
                debug_assert!(legacy_batch_size >= num_rows);
                legacy_batch_size == num_rows
            }
            None => false,
        }
    }

    /// Computes the offsets, relative to the start of the restored batch, at
    /// which blank rows must be inserted into a batch of `num_rows` live rows
    /// that starts at `current_row_id`.
    ///
    /// Every offset that is emitted grows the batch by one row, which may in
    /// turn pull further deleted ids into range.
    fn deleted_batch_offsets_in_range(&mut self, mut num_rows: u32) -> Vec<u32> {
        let mut offsets = Vec::new();
        let batch_start = self.current_row_id;
        // One past the last row id currently covered by the batch.
        let mut batch_end = batch_start + num_rows;

        let Some(deletions) = self.deletion_vector_iter.as_mut() else {
            return offsets;
        };

        // Resume with the id left over from the previous batch, if any.
        let mut candidate = self.last_deleted_row_id.or_else(|| deletions.next());

        while let Some(deleted_id) = candidate {
            // An id exactly at the end of the batch extends the batch unless
            // the batch has already reached its fixed size.
            let belongs_to_later_batch = deleted_id > batch_end
                || (deleted_id == batch_end && Self::is_full(self.legacy_batch_size, num_rows));
            if belongs_to_later_batch {
                self.last_deleted_row_id = Some(deleted_id);
                return offsets;
            }

            offsets.push(deleted_id - batch_start);
            batch_end += 1;
            num_rows += 1;
            candidate = deletions.next();
        }

        // The deletion vector has been fully consumed.
        self.deletion_vector_iter = None;
        offsets
    }

    /// Pads `batch` with blank rows for the deleted rows that fall inside it
    /// and advances the current row id by the size of the padded batch.
    ///
    /// # Errors
    ///
    /// * Any error from `add_blanks`.
    /// * `Error::Internal` when a fixed batch size applies, the padded batch
    ///   has a different size, and deletions are still pending.
    fn restore(&mut self, batch: RecordBatch) -> Result<RecordBatch> {
        let live_rows = batch.num_rows() as u32;
        let blank_offsets = self.deleted_batch_offsets_in_range(live_rows);
        let restored = add_blanks(batch, &blank_offsets)?;
        let restored_rows = restored.num_rows();

        match self.legacy_batch_size {
            Some(batch_size)
                if restored_rows != batch_size as usize && !self.is_exhausted() =>
            {
                return Err(Error::Internal {
                    message: format!(
                        "Fragment Updater: batch size mismatch: {} != {}",
                        restored_rows, batch_size
                    ),
                    location: location!(),
                });
            }
            _ => {}
        }

        self.current_row_id += restored_rows as u32;
        Ok(restored)
    }
}
/// Inserts placeholder rows into `batch` at the given output positions.
///
/// `batch_offsets` lists, in ascending order, the row positions of the
/// *resulting* batch that must hold a placeholder. Each placeholder is a copy
/// of row 0 of `batch`; all other rows keep their relative order.
///
/// # Errors
///
/// * `Error::NotSupported` when placeholders are requested for an empty
///   batch, since there is no row to copy.
/// * `Error::Arrow` when the `take` kernel fails on a column.
/// * Any error from reassembling the record batch.
pub(crate) fn add_blanks(batch: RecordBatch, batch_offsets: &[u32]) -> Result<RecordBatch> {
    // Nothing to insert: hand the batch back untouched.
    if batch_offsets.is_empty() {
        return Ok(batch);
    }

    if batch.num_rows() == 0 {
        return Err(Error::NotSupported {
            source: "Missing too many rows in merge, run compaction to materialize deletions first"
                .into(),
            location: location!(),
        });
    }

    // Build the row indices to take from `batch`, one per output row.
    let mut take_indices = Vec::<u32>::with_capacity(batch.num_rows() + batch_offsets.len());
    let mut source_pos = 0;
    let mut output_pos = 0;
    for &blank_offset in batch_offsets {
        // Copy the run of live rows that precedes this placeholder.
        let live_run = blank_offset - output_pos;
        take_indices.extend(source_pos..source_pos + live_run);
        source_pos += live_run;

        // The placeholder itself re-uses row 0.
        take_indices.push(0);
        output_pos = blank_offset + 1;
    }
    // Copy whatever live rows remain after the last placeholder.
    take_indices.extend(source_pos..batch.num_rows() as u32);

    let take_indices = UInt32Array::from(take_indices);

    let mut padded_columns = Vec::with_capacity(batch.columns().len());
    for column in batch.columns() {
        let padded = arrow::compute::take(column.as_ref(), &take_indices, None).map_err(|e| {
            Error::Arrow {
                message: format!("Failed to add blanks: {}", e),
                location: location!(),
            }
        })?;
        padded_columns.push(padded);
    }

    let padded_batch = RecordBatch::try_new(batch.schema(), padded_columns)?;
    Ok(padded_batch)
}
#[cfg(test)]
mod tests {
    use arrow::{array::AsArray, datatypes::Int32Type};
    use lance_datagen::RowCount;

    use super::add_blanks;

    /// Generates a batch with a single `Int32` column `x` filled by the
    /// `step` generator (the expectations below rely on it yielding 0, 1, 2, ...).
    fn step_batch(rows: RowCount) -> arrow_array::RecordBatch {
        lance_datagen::gen()
            .col("x", lance_datagen::array::step::<Int32Type>())
            .into_batch_rows(rows)
            .unwrap()
    }

    /// Copies the first column of `batch` into a plain vector of `i32`.
    fn first_column(batch: &arrow_array::RecordBatch) -> Vec<i32> {
        let values = batch.column(0).as_primitive::<Int32Type>();
        (0..values.len()).map(|i| values.value(i)).collect()
    }

    #[test]
    fn test_restore_deletes() {
        for batch_size in [None, Some(10)] {
            let mut restorer = super::DeletionRestorer::new(
                vec![11, 12, 19, 20, 25].into_iter().collect(),
                batch_size,
            );

            // No deleted row falls inside the first ten rows.
            let untouched = step_batch(RowCount::from(10));
            let restored = restorer.restore(untouched.clone()).unwrap();
            assert_eq!(restored, untouched);

            // Rows 11, 12 and 19 become blanks (copies of row 0) in the
            // second batch.
            let restored = restorer.restore(step_batch(RowCount::from(7))).unwrap();
            let mut expected = vec![0, 0, 0, 1, 2, 3, 4, 5, 6, 0];
            if batch_size.is_none() {
                // Without a fixed batch size, row 20 is absorbed as well.
                expected.push(0);
            }
            assert_eq!(first_column(&restored), expected);
        }
    }

    #[test]
    fn test_add_blanks() {
        let batch = step_batch(RowCount::from(10));

        // Blanks in the middle of the batch.
        let with_blanks = add_blanks(batch.clone(), &[5, 7]).unwrap();
        assert_eq!(with_blanks.num_rows(), 12);
        assert_eq!(
            first_column(&with_blanks),
            vec![0, 1, 2, 3, 4, 0, 5, 0, 6, 7, 8, 9]
        );

        // Blanks at the very start and the very end; only the first twelve
        // rows are pinned.
        let with_blanks = add_blanks(batch, &[0, 11]).unwrap();
        let values = first_column(&with_blanks);
        assert_eq!(values[..12], [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]);
    }
}