use arrow_array::{ArrayRef, UInt64Array};
use datafusion::execution::SendableRecordBatchStream;
use futures::TryStreamExt;
use lance_core::error::Error;
use lance_core::utils::address::RowAddress;
use lance_core::utils::mask::RowAddrTreeMap;
use lance_core::{ROW_ADDR, Result};
use lance_datafusion::chunker::chunk_concat_stream;
/// Identifies the span of row offsets a zone covers within a single fragment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZoneBound {
    /// Fragment this zone belongs to (zones never span fragments).
    pub fragment_id: u64,
    /// Row offset (within the fragment) of the first row in the zone.
    pub start: u64,
    /// Inclusive span of row offsets covered (`last - start + 1`); may exceed
    /// the number of physical rows when offsets have gaps (e.g. deletions).
    pub length: usize,
}
/// Callback interface driven by [`ZoneTrainer`] while it slices a row stream
/// into zones.
pub trait ZoneProcessor {
    /// Per-zone summary produced when a zone is closed.
    type ZoneStatistics;
    /// Accumulates a slice of values belonging to the current zone.
    fn process_chunk(&mut self, values: &ArrayRef) -> Result<()>;
    /// Closes the current zone and returns its statistics for `bound`.
    fn finish_zone(&mut self, bound: ZoneBound) -> Result<Self::ZoneStatistics>;
    /// Clears any accumulated state so a fresh zone can begin.
    fn reset(&mut self) -> Result<()>;
}
/// Drives a [`ZoneProcessor`] over a record batch stream, partitioning it into
/// zones of at most `zone_capacity` rows that never cross fragment boundaries.
#[derive(Debug)]
pub struct ZoneTrainer<P> {
    // Collects per-zone statistics; reset between zones.
    processor: P,
    // Maximum number of rows per zone; validated non-zero in `new`.
    zone_capacity: u64,
}
impl<P> ZoneTrainer<P>
where
P: ZoneProcessor,
{
pub fn new(processor: P, zone_capacity: u64) -> Result<Self> {
if zone_capacity == 0 {
return Err(Error::invalid_input(
"zone capacity must be greater than zero",
));
}
Ok(Self {
processor,
zone_capacity,
})
}
pub async fn train(
mut self,
stream: SendableRecordBatchStream,
) -> Result<Vec<P::ZoneStatistics>> {
let zone_size = usize::try_from(self.zone_capacity).map_err(|_| {
Error::invalid_input("zone capacity does not fit into usize on this platform")
})?;
let mut batches = chunk_concat_stream(stream, zone_size);
let mut zones = Vec::new();
let mut current_fragment_id: Option<u64> = None;
let mut current_zone_len: usize = 0;
let mut zone_start_offset: Option<u64> = None;
let mut zone_end_offset: Option<u64> = None;
self.processor.reset()?;
while let Some(batch) = batches.try_next().await? {
if batch.num_rows() == 0 {
continue;
}
let values = batch.column(0);
let row_addr_col = batch
.column_by_name(ROW_ADDR)
.unwrap()
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
let mut batch_offset = 0usize;
while batch_offset < batch.num_rows() {
let row_addr = row_addr_col.value(batch_offset);
let fragment_id = row_addr >> 32;
match current_fragment_id {
Some(current) if current != fragment_id => {
if current_zone_len > 0 {
Self::flush_zone(
&mut self.processor,
&mut zones,
current,
&mut current_zone_len,
&mut zone_start_offset,
&mut zone_end_offset,
)?;
}
current_fragment_id = Some(fragment_id);
}
None => {
current_fragment_id = Some(fragment_id);
}
_ => {}
}
let run_len = (batch_offset..batch.num_rows())
.take_while(|&idx| (row_addr_col.value(idx) >> 32) == fragment_id)
.count();
let capacity = zone_size - current_zone_len;
let take = run_len.min(capacity);
self.processor
.process_chunk(&values.slice(batch_offset, take))?;
let first_offset =
RowAddress::new_from_u64(row_addr_col.value(batch_offset)).row_offset() as u64;
let last_offset =
RowAddress::new_from_u64(row_addr_col.value(batch_offset + take - 1))
.row_offset() as u64;
if zone_start_offset.is_none() {
zone_start_offset = Some(first_offset);
}
zone_end_offset = Some(last_offset);
current_zone_len += take;
batch_offset += take;
if current_zone_len == zone_size {
Self::flush_zone(
&mut self.processor,
&mut zones,
fragment_id,
&mut current_zone_len,
&mut zone_start_offset,
&mut zone_end_offset,
)?;
}
}
}
if current_zone_len > 0 {
if let Some(fragment_id) = current_fragment_id {
Self::flush_zone(
&mut self.processor,
&mut zones,
fragment_id,
&mut current_zone_len,
&mut zone_start_offset,
&mut zone_end_offset,
)?;
} else {
self.processor.reset()?;
}
}
Ok(zones)
}
fn flush_zone(
processor: &mut P,
zones: &mut Vec<P::ZoneStatistics>,
fragment_id: u64,
current_zone_len: &mut usize,
zone_start_offset: &mut Option<u64>,
zone_end_offset: &mut Option<u64>,
) -> Result<()> {
let start = zone_start_offset.unwrap_or(0);
let inferred_end =
zone_end_offset.unwrap_or_else(|| start + (*current_zone_len as u64).saturating_sub(1));
if inferred_end < start {
return Err(Error::invalid_input("zone row offsets are out of order"));
}
let bound = ZoneBound {
fragment_id,
start,
length: (inferred_end - start + 1) as usize,
};
let stats = processor.finish_zone(bound)?;
zones.push(stats);
*current_zone_len = 0;
*zone_start_offset = None;
*zone_end_offset = None;
processor.reset()?;
Ok(())
}
}
/// Evaluates `zone_matches` against every zone and collects the row-address
/// ranges of all matching zones into an at-most search result.
///
/// Records one comparison per zone in `metrics`. Because a zone match only
/// proves the zone *may* contain qualifying rows, the result is
/// `SearchResult::at_most` rather than an exact answer.
pub fn search_zones<T, F>(
    zones: &[T],
    metrics: &dyn crate::metrics::MetricsCollector,
    mut zone_matches: F,
) -> Result<crate::scalar::SearchResult>
where
    T: AsRef<ZoneBound>,
    F: FnMut(&T) -> Result<bool>,
{
    // Every zone is inspected exactly once.
    metrics.record_comparisons(zones.len());
    let mut matched_rows = RowAddrTreeMap::new();
    for zone in zones {
        if !zone_matches(zone)? {
            continue;
        }
        let bound = zone.as_ref();
        // Row addresses are (fragment_id << 32) + row offset; `length` is the
        // inclusive offset span, so `start + length` is the exclusive end.
        let start_addr = (bound.fragment_id << 32) + bound.start;
        let end_addr = start_addr + bound.length as u64;
        matched_rows.insert_range(start_addr..end_addr);
    }
    Ok(crate::scalar::SearchResult::at_most(matched_rows))
}
/// Trains new zones from `stream` and returns the existing statistics
/// followed by the freshly trained ones.
///
/// # Errors
///
/// Propagates any error from [`ZoneTrainer::train`].
pub async fn rebuild_zones<P>(
    existing: &[P::ZoneStatistics],
    trainer: ZoneTrainer<P>,
    stream: SendableRecordBatchStream,
) -> Result<Vec<P::ZoneStatistics>>
where
    P: ZoneProcessor,
    P::ZoneStatistics: Clone,
{
    let new_zones = trainer.train(stream).await?;
    // Preserve ordering: existing zones first, newly trained zones after.
    let mut combined = Vec::with_capacity(existing.len() + new_zones.len());
    combined.extend_from_slice(existing);
    combined.extend(new_zones);
    Ok(combined)
}
// Unit tests exercising zone training, zone search, and rebuild behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{metrics::LocalMetricsCollector, scalar::SearchResult};
    use arrow_array::{ArrayRef, Int32Array, RecordBatch, UInt64Array};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::stream;
    use lance_core::ROW_ADDR;
    use std::sync::Arc;

    // Statistics produced by MockProcessor: sum of the zone's values plus the
    // bound the trainer reported for the zone.
    #[derive(Debug, Clone, PartialEq)]
    struct MockStats {
        sum: i32,
        bound: ZoneBound,
    }

    // Minimal ZoneProcessor that sums Int32 values per zone.
    #[derive(Debug)]
    struct MockProcessor {
        current_sum: i32,
    }

    impl MockProcessor {
        fn new() -> Self {
            Self { current_sum: 0 }
        }
    }

    impl ZoneProcessor for MockProcessor {
        type ZoneStatistics = MockStats;

        fn process_chunk(&mut self, values: &ArrayRef) -> Result<()> {
            let arr = values.as_any().downcast_ref::<Int32Array>().unwrap();
            // Null values count as zero.
            self.current_sum += arr.iter().map(|v| v.unwrap_or(0)).sum::<i32>();
            Ok(())
        }

        fn finish_zone(&mut self, bound: ZoneBound) -> Result<Self::ZoneStatistics> {
            Ok(MockStats {
                sum: self.current_sum,
                bound,
            })
        }

        fn reset(&mut self) -> Result<()> {
            self.current_sum = 0;
            Ok(())
        }
    }

    // Builds a two-column batch: "value" (Int32) and ROW_ADDR (UInt64),
    // where each row address is assembled from a (fragment, offset) pair.
    fn batch(values: Vec<i32>, fragments: Vec<u64>, offsets: Vec<u64>) -> RecordBatch {
        let val_array = Arc::new(Int32Array::from(values));
        let row_addrs: Vec<u64> = fragments
            .into_iter()
            .zip(offsets)
            .map(|(frag, off)| (frag << 32) | off)
            .collect();
        let addr_array = Arc::new(UInt64Array::from(row_addrs));
        let schema = Arc::new(Schema::new(vec![
            Field::new("value", DataType::Int32, false),
            Field::new(ROW_ADDR, DataType::UInt64, false),
        ]));
        RecordBatch::try_new(schema, vec![val_array, addr_array]).unwrap()
    }

    // 10 rows, capacity 4 -> two full zones and one 2-row trailing zone.
    #[tokio::test]
    async fn splits_single_fragment() {
        let values = vec![1; 10];
        let offsets: Vec<u64> = (0..10).collect();
        let batch = batch(values, vec![0; 10], offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 4).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 3);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 4);
        assert_eq!(stats[1].bound.start, 4);
        assert_eq!(stats[1].bound.length, 4);
        assert_eq!(stats[2].bound.start, 8);
        assert_eq!(stats[2].bound.length, 2);
        assert_eq!(
            stats.iter().map(|s| s.sum).collect::<Vec<_>>(),
            vec![4, 4, 2]
        );
    }

    // A fragment change flushes the zone even when capacity isn't reached.
    #[tokio::test]
    async fn flushes_on_fragment_boundary() {
        let values = vec![1, 1, 1, 2, 2, 2];
        let fragments = vec![0, 0, 0, 1, 1, 1];
        let offsets = vec![0, 1, 2, 0, 1, 2];
        let batch = batch(values, fragments, offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 10).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].bound.fragment_id, 0);
        assert_eq!(stats[0].bound.length, 3);
        assert_eq!(stats[1].bound.fragment_id, 1);
        assert_eq!(stats[1].bound.length, 3);
    }

    // A zone whose last offset precedes its first offset must error.
    #[tokio::test]
    async fn errors_on_out_of_order_offsets() {
        let values = vec![1, 2, 3];
        let fragments = vec![0, 0, 0];
        let offsets = vec![5, 3, 4];
        let batch = batch(values, fragments, offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 10).unwrap();
        let err = trainer.train(stream).await.unwrap_err();
        assert!(
            format!("{}", err).contains("zone row offsets are out of order"),
            "unexpected error: {err:?}"
        );
    }

    // Empty batches before/after real data are skipped entirely.
    #[tokio::test]
    async fn handles_empty_batches() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("value", DataType::Int32, false),
            Field::new(ROW_ADDR, DataType::UInt64, false),
        ]));
        let empty_batch = RecordBatch::new_empty(schema.clone());
        let valid_batch = batch(vec![1, 2, 3], vec![0, 0, 0], vec![0, 1, 2]);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(vec![
                Ok(empty_batch.clone()),
                Ok(valid_batch),
                Ok(empty_batch),
            ]),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 10).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].sum, 6);
        assert_eq!(stats[0].bound.fragment_id, 0);
        assert_eq!(stats[0].bound.length, 3);
    }

    // The smallest legal capacity produces one zone per row.
    #[tokio::test]
    async fn handles_zone_capacity_one() {
        let values = vec![10, 20, 30];
        let offsets = vec![0, 1, 2];
        let batch = batch(values.clone(), vec![0, 0, 0], offsets.clone());
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 1).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 3);
        for (i, stat) in stats.iter().enumerate() {
            assert_eq!(stat.bound.fragment_id, 0);
            assert_eq!(stat.bound.start, offsets[i]);
            assert_eq!(stat.bound.length, 1);
            assert_eq!(stat.sum, values[i]);
        }
    }

    // Capacity far larger than the data yields a single zone.
    #[tokio::test]
    async fn handles_large_capacity() {
        let values = vec![1; 100];
        let offsets: Vec<u64> = (0..100).collect();
        let batch = batch(values, vec![0; 100], offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 10000).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].sum, 100);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 100);
    }

    // Constructing a trainer with capacity zero must fail.
    #[tokio::test]
    async fn rejects_zero_capacity() {
        let processor = MockProcessor::new();
        let result = ZoneTrainer::new(processor, 0);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("zone capacity must be greater than zero")
        );
    }

    // Zones may span multiple incoming batches of the same fragment.
    #[tokio::test]
    async fn handles_multiple_batches_same_fragment() {
        let b1 = batch(vec![1, 1], vec![0, 0], vec![0, 1]);
        let b2 = batch(vec![1, 1], vec![0, 0], vec![2, 3]);
        let b3 = batch(vec![1, 1], vec![0, 0], vec![4, 5]);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            b1.schema(),
            stream::iter(vec![Ok(b1), Ok(b2), Ok(b3)]),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 4).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].bound.fragment_id, 0);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 4);
        assert_eq!(stats[0].sum, 4);
        assert_eq!(stats[1].bound.fragment_id, 0);
        assert_eq!(stats[1].bound.start, 4);
        assert_eq!(stats[1].bound.length, 2);
        assert_eq!(stats[1].sum, 2);
    }

    // Capacity flush and fragment-boundary flush interact across batches.
    #[tokio::test]
    async fn handles_multi_batch_with_fragment_change() {
        let b1 = batch(vec![1, 1], vec![0, 0], vec![0, 1]);
        let b2 = batch(vec![1, 1, 2, 2], vec![0, 0, 1, 1], vec![2, 3, 0, 1]);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            b1.schema(),
            stream::iter(vec![Ok(b1), Ok(b2)]),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 3).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 3);
        assert_eq!(stats[0].bound.fragment_id, 0);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 3);
        assert_eq!(stats[0].sum, 3);
        assert_eq!(stats[1].bound.fragment_id, 0);
        assert_eq!(stats[1].bound.start, 3);
        assert_eq!(stats[1].bound.length, 1);
        assert_eq!(stats[1].sum, 1);
        assert_eq!(stats[2].bound.fragment_id, 1);
        assert_eq!(stats[2].bound.start, 0);
        assert_eq!(stats[2].bound.length, 2);
        assert_eq!(stats[2].sum, 4);
    }

    // Offset gaps (deleted rows) widen the zone's length beyond its row count.
    #[tokio::test]
    async fn handles_non_contiguous_offsets_after_deletion() {
        let values = vec![1, 1, 1, 1, 1, 1];
        let fragments = vec![0, 0, 0, 0, 0, 0];
        let offsets = vec![0, 1, 5, 7, 8, 9];
        let batch = batch(values, fragments, offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 4).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].sum, 4);
        assert_eq!(stats[0].bound.fragment_id, 0);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 8);
        assert_eq!(stats[1].sum, 2);
        assert_eq!(stats[1].bound.fragment_id, 0);
        assert_eq!(stats[1].bound.start, 8);
        assert_eq!(stats[1].bound.length, 2);
    }

    // Very large offset gaps still produce one zone spanning the whole range.
    #[tokio::test]
    async fn handles_deletion_with_large_gaps() {
        let values = vec![1, 1, 1];
        let fragments = vec![0, 0, 0];
        let offsets = vec![0, 100, 200];
        let batch = batch(values, fragments, offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 10).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].sum, 3);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 201);
    }

    // Fragment ids need not be consecutive; each change still flushes.
    #[tokio::test]
    async fn handles_non_contiguous_fragment_ids() {
        let values = vec![1, 1, 2, 2, 3, 3];
        let fragments = vec![0, 0, 5, 5, 10, 10];
        let offsets = vec![0, 1, 0, 1, 0, 1];
        let batch = batch(values, fragments, offsets);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let processor = MockProcessor::new();
        let trainer = ZoneTrainer::new(processor, 10).unwrap();
        let stats = trainer.train(stream).await.unwrap();
        assert_eq!(stats.len(), 3);
        assert_eq!(stats[0].bound.fragment_id, 0);
        assert_eq!(stats[0].bound.start, 0);
        assert_eq!(stats[0].bound.length, 2);
        assert_eq!(stats[0].sum, 2);
        assert_eq!(stats[1].bound.fragment_id, 5);
        assert_eq!(stats[1].bound.start, 0);
        assert_eq!(stats[1].bound.length, 2);
        assert_eq!(stats[1].sum, 4);
        assert_eq!(stats[2].bound.fragment_id, 10);
        assert_eq!(stats[2].bound.start, 0);
        assert_eq!(stats[2].bound.length, 2);
        assert_eq!(stats[2].sum, 6);
    }

    // Matching zones contribute their full row-address range; others nothing.
    #[test]
    fn search_zones_collects_row_ranges() {
        #[derive(Debug)]
        struct DummyZone {
            bound: ZoneBound,
            matches: bool,
        }

        impl AsRef<ZoneBound> for DummyZone {
            fn as_ref(&self) -> &ZoneBound {
                &self.bound
            }
        }

        let zones = vec![
            DummyZone {
                bound: ZoneBound {
                    fragment_id: 0,
                    start: 0,
                    length: 2,
                },
                matches: true,
            },
            DummyZone {
                bound: ZoneBound {
                    fragment_id: 1,
                    start: 5,
                    length: 3,
                },
                matches: false,
            },
            DummyZone {
                bound: ZoneBound {
                    fragment_id: 2,
                    start: 10,
                    length: 1,
                },
                matches: true,
            },
        ];
        let metrics = LocalMetricsCollector::default();
        let result = search_zones(&zones, &metrics, |zone| Ok(zone.matches)).unwrap();
        let SearchResult::AtMost(map) = result else {
            panic!("search_zones should return AtMost for dummy zones");
        };
        assert!(map.selected(0));
        assert!(map.selected(1));
        assert!(!map.selected((1_u64 << 32) + 5));
        assert!(!map.selected((1_u64 << 32) + 7));
        assert!(map.selected((2_u64 << 32) + 10));
        assert!(!map.selected((2_u64 << 32) + 11));
    }

    // No matching zones -> an empty AtMost result.
    #[test]
    fn search_zones_returns_empty_when_no_match() {
        #[derive(Debug)]
        struct DummyZone {
            bound: ZoneBound,
            matches: bool,
        }

        impl AsRef<ZoneBound> for DummyZone {
            fn as_ref(&self) -> &ZoneBound {
                &self.bound
            }
        }

        let zones = vec![
            DummyZone {
                bound: ZoneBound {
                    fragment_id: 0,
                    start: 0,
                    length: 4,
                },
                matches: false,
            },
            DummyZone {
                bound: ZoneBound {
                    fragment_id: 1,
                    start: 10,
                    length: 2,
                },
                matches: false,
            },
        ];
        let metrics = LocalMetricsCollector::default();
        let result = search_zones(&zones, &metrics, |zone| Ok(zone.matches)).unwrap();
        let SearchResult::AtMost(map) = result else {
            panic!("expected AtMost result");
        };
        assert!(map.is_empty());
    }

    // rebuild_zones keeps existing stats first, then appends new ones.
    #[tokio::test]
    async fn rebuild_zones_appends_new_stats() {
        let existing = vec![MockStats {
            sum: 50,
            bound: ZoneBound {
                fragment_id: 0,
                start: 0,
                length: 2,
            },
        }];
        let batch = batch(vec![3, 4], vec![1, 1], vec![0, 1]);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let trainer = ZoneTrainer::new(MockProcessor::new(), 2).unwrap();
        let rebuilt = rebuild_zones(&existing, trainer, stream).await.unwrap();
        assert_eq!(rebuilt.len(), 2);
        assert_eq!(rebuilt[0].sum, 50);
        assert_eq!(rebuilt[1].sum, 7);
        assert_eq!(rebuilt[1].bound.fragment_id, 1);
        assert_eq!(rebuilt[1].bound.start, 0);
        assert_eq!(rebuilt[1].bound.length, 2);
    }

    // rebuild_zones trains one zone per fragment from the incoming stream.
    #[tokio::test]
    async fn rebuild_zones_handles_multi_fragment_stream() {
        let existing = vec![MockStats {
            sum: 10,
            bound: ZoneBound {
                fragment_id: 0,
                start: 0,
                length: 1,
            },
        }];
        let batch = batch(vec![5, 5, 6, 6], vec![1, 1, 2, 2], vec![0, 1, 0, 1]);
        let stream = Box::pin(RecordBatchStreamAdapter::new(
            batch.schema(),
            stream::once(async { Ok(batch) }),
        ));
        let trainer = ZoneTrainer::new(MockProcessor::new(), 2).unwrap();
        let rebuilt = rebuild_zones(&existing, trainer, stream).await.unwrap();
        assert_eq!(rebuilt.len(), 3);
        assert_eq!(rebuilt[0].bound.fragment_id, 0);
        assert_eq!(rebuilt[1].bound.fragment_id, 1);
        assert_eq!(rebuilt[2].bound.fragment_id, 2);
        assert_eq!(rebuilt[1].sum, 10);
        assert_eq!(rebuilt[2].sum, 12);
    }
}