use std::mem::MaybeUninit;
use std::ops::Range;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use arrow_array::cast::AsArray;
use arrow_array::types::Float32Type;
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use lance_core::{Error, ROW_ID, Result};
use lance_linalg::distance::{DistanceType, Dot, L2, cosine_distance};
use super::graph::ScoredPoint;
/// Name of the vector column in the Arrow schema built by
/// `ArrowFixedSizeListVectorStore::try_new`.
pub const FLAT_COLUMN: &str = "flat";
/// One appended batch together with a cached pointer into its `f32` values.
#[derive(Clone)]
struct StoredArrowBatch {
/// Owning handle to the appended array. It is not read in this file; it is held so the
/// buffer behind `values_ptr` stays alive for as long as the batch is stored.
_array: Arc<FixedSizeListArray>,
/// Pointer to the first `f32` of the batch's values, captured in `append_batch`.
values_ptr: *const f32,
}
/// Location of one stored row: which batch it lives in and its row position inside that batch.
#[derive(Copy, Clone)]
struct RowLookup {
/// Index into the store's batch slot array.
batch_idx: u32,
/// Row offset within the batch (in rows, not in `f32` elements).
offset: u32,
}
/// Read-only access to a set of `f32` vectors addressed by `u32` ids.
///
/// Implementors provide the storage accessors; the distance helpers are supplied as default
/// methods built on top of them.
pub trait VectorSource: Send + Sync {
    /// Number of vectors reachable through this source.
    fn len(&self) -> usize;

    /// Dimensionality reported for the vectors of this source.
    fn dim(&self) -> usize;

    /// Metric used by the default distance helpers.
    fn distance_type(&self) -> DistanceType;

    /// Row id associated with the vector `id`.
    fn row_id(&self, id: u32) -> u64;

    /// Borrowed components of the vector `id`.
    fn vector(&self, id: u32) -> &[f32];

    /// Returns `true` when `len()` is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Distance from an external `query` to the stored vector `id`.
    fn distance_to(&self, query: &[f32], id: u32) -> f32 {
        let stored = self.vector(id);
        let metric = self.distance_type();
        compute_f32_distance(query, stored, metric)
    }

    /// Distance between the stored vectors `left` and `right`.
    fn distance_between(&self, left: u32, right: u32) -> f32 {
        let lhs = self.vector(left);
        let rhs = self.vector(right);
        let metric = self.distance_type();
        compute_f32_distance(lhs, rhs, metric)
    }

    /// Returns `true` when `candidate.distance` is strictly smaller than the distance between
    /// the candidate and every point in `selected` (vacuously `true` for an empty slice).
    fn prefers_candidate(&self, candidate: ScoredPoint, selected: &[ScoredPoint]) -> bool {
        for other in selected {
            let spacing = self.distance_between(candidate.id, other.id);
            if candidate.distance < spacing {
                continue;
            }
            // Not strictly closer than this neighbour (this also covers NaN comparisons).
            return false;
        }
        true
    }
}
/// Computes the distance between `query` and `vector` under `distance_type`.
///
/// `Hamming` has no `f32` kernel in this function, so that arm yields `f32::INFINITY`.
pub fn compute_f32_distance(query: &[f32], vector: &[f32], distance_type: DistanceType) -> f32 {
    match distance_type {
        // Unsupported for `f32` data: report "infinitely far" instead of computing anything.
        DistanceType::Hamming => f32::INFINITY,
        DistanceType::Cosine => cosine_distance(query, vector),
        DistanceType::Dot => f32::dot(query, vector),
        DistanceType::L2 => f32::l2(query, vector),
    }
}
/// Append-only store of `f32` vectors kept as Arrow `FixedSizeListArray` batches.
///
/// The three raw pointers are fixed-size slot arrays allocated in `try_new` and released in
/// `Drop`. `append_batch` fills slots and then publishes them through the two atomic counters.
pub struct ArrowFixedSizeListVectorStore {
/// Slot array of `max_batches` entries; the first `committed_batches` entries are initialized.
batches: *mut MaybeUninit<StoredArrowBatch>,
/// Slot array of `capacity` entries mapping an internal row id to its batch and offset.
row_to_batch: *mut MaybeUninit<RowLookup>,
/// Slot array of `capacity` entries holding the row id of each internal row.
row_ids: *mut MaybeUninit<u64>,
/// Number of batch slots that have been written and published.
committed_batches: AtomicUsize,
/// Number of row slots that have been written and published.
committed_len: AtomicUsize,
/// Length of the `row_to_batch` and `row_ids` slot arrays.
capacity: usize,
/// Length of the `batches` slot array.
max_batches: usize,
/// Number of `f32` values per vector.
dim: usize,
/// Metric reported to `VectorSource` consumers.
distance_type: DistanceType,
/// Schema (`ROW_ID`, `FLAT_COLUMN`) used when exporting a `RecordBatch`.
schema: SchemaRef,
}
// SAFETY (Send): the raw pointers are heap allocations created in `try_new` via
// `Box::into_raw` and released only in `Drop`, so they travel with the struct.
// NOTE(review): this also relies on the stored `Arc<FixedSizeListArray>` being safe to move
// across threads — confirm.
unsafe impl Send for ArrowFixedSizeListVectorStore {}
// SAFETY (Sync): `append_batch` writes slots first and then publishes them with `Release`
// stores; readers load the counters with `Acquire` before touching slots.
// NOTE(review): `append_batch` takes `&self` and its load/write/store sequence is not atomic
// as a whole, so this looks sound only with one appender at a time — confirm that callers
// serialize appends.
unsafe impl Sync for ArrowFixedSizeListVectorStore {}
impl Drop for ArrowFixedSizeListVectorStore {
    /// Drops every published batch, then frees the three slot arrays allocated in `try_new`.
    fn drop(&mut self) {
        /// Rebuilds the boxed slice that was leaked with `Box::into_raw` and drops it.
        /// `MaybeUninit` elements have no drop glue, so this only releases the allocation.
        unsafe fn release<T>(slots: *mut MaybeUninit<T>, len: usize) {
            // SAFETY: the caller passes a pointer/length pair that came from a
            // `Box<[MaybeUninit<T>]>` of exactly `len` elements.
            unsafe {
                drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                    slots, len,
                )));
            }
        }

        let live_batches = self.committed_batches.load(Ordering::Acquire);
        // SAFETY: only the first `live_batches` batch slots were initialized by
        // `append_batch`; the three pointers were produced by `Box::into_raw` in `try_new`
        // with lengths `max_batches`, `capacity` and `capacity` respectively.
        unsafe {
            let first_batch = self.batches.cast::<StoredArrowBatch>();
            (0..live_batches).for_each(|slot| std::ptr::drop_in_place(first_batch.add(slot)));
            release(self.batches, self.max_batches);
            release(self.row_to_batch, self.capacity);
            release(self.row_ids, self.capacity);
        }
    }
}
impl ArrowFixedSizeListVectorStore {
pub fn try_new(
capacity: usize,
max_batches: usize,
dim: usize,
distance_type: DistanceType,
) -> Result<Self> {
if capacity == 0 {
return Err(Error::invalid_input("capacity must be greater than 0"));
}
if max_batches == 0 {
return Err(Error::invalid_input("max_batches must be greater than 0"));
}
if dim == 0 {
return Err(Error::invalid_input("dim must be greater than 0"));
}
if capacity > u32::MAX as usize {
return Err(Error::invalid_input(format!(
"capacity must fit in u32, got {capacity}"
)));
}
if max_batches > u32::MAX as usize {
return Err(Error::invalid_input(format!(
"max_batches must fit in u32, got {max_batches}"
)));
}
if distance_type == DistanceType::Hamming {
return Err(Error::invalid_input(
"ArrowFixedSizeListVectorStore stores f32 vectors and does not support hamming distance",
));
}
let batches: Box<[MaybeUninit<StoredArrowBatch>]> = uninit_boxed_slice(max_batches);
let row_to_batch: Box<[MaybeUninit<RowLookup>]> = uninit_boxed_slice(capacity);
let row_ids: Box<[MaybeUninit<u64>]> = uninit_boxed_slice(capacity);
let schema = Arc::new(ArrowSchema::new(vec![
Field::new(ROW_ID, DataType::UInt64, false),
Field::new(
FLAT_COLUMN,
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dim as i32,
),
false,
),
]));
Ok(Self {
batches: Box::into_raw(batches) as *mut MaybeUninit<StoredArrowBatch>,
row_to_batch: Box::into_raw(row_to_batch) as *mut MaybeUninit<RowLookup>,
row_ids: Box::into_raw(row_ids) as *mut MaybeUninit<u64>,
committed_batches: AtomicUsize::new(0),
committed_len: AtomicUsize::new(0),
capacity,
max_batches,
dim,
distance_type,
schema,
})
}
/// Returns the number of rows published so far (an `Acquire` load of the row counter).
pub fn committed_len(&self) -> usize {
self.committed_len.load(Ordering::Acquire)
}
/// Returns the number of `f32` values per stored vector, as given to `try_new`.
pub fn dim(&self) -> usize {
self.dim
}
/// Appends `vectors` as one batch and assigns consecutive row ids starting at `row_id_start`.
///
/// Returns the half-open range of internal ids given to the appended rows. An empty batch is
/// accepted without consuming a batch slot and yields an empty range at the current length.
///
/// # Errors
/// Returns an invalid-input error when the list width differs from the store's `dim`, when
/// the batch contains null rows, when the child values are not `Float32`, when the rows would
/// exceed `capacity`, when all `max_batches` slots are used, or when the assigned row ids
/// would overflow `u64`. Nothing is written or published on error.
///
/// NOTE(review): the load / write / store sequence below is not atomic as a whole, so this
/// appears to assume a single appender at a time — confirm that callers serialize appends.
pub fn append_batch(
    &self,
    vectors: Arc<FixedSizeListArray>,
    row_id_start: u64,
) -> Result<Range<u32>> {
    if vectors.is_empty() {
        let start = self.committed_len.load(Ordering::Relaxed) as u32;
        return Ok(start..start);
    }
    if vectors.value_length() as usize != self.dim {
        return Err(Error::invalid_input(format!(
            "vector dimension mismatch: expected {}, got {}",
            self.dim,
            vectors.value_length()
        )));
    }
    if vectors.null_count() > 0 {
        return Err(Error::invalid_input(format!(
            "null vectors are not supported, got {} null row(s)",
            vectors.null_count()
        )));
    }
    let values = vectors.values();
    let Some(values_f32) = values.as_primitive_opt::<Float32Type>() else {
        return Err(Error::invalid_input(format!(
            "vector values must be Float32, got {:?}",
            values.data_type()
        )));
    };
    // Captured now so that `vectors` can be moved into the batch slot below; the pointed-to
    // buffer stays alive because the slot keeps the `Arc`.
    let values_ptr = values_f32.values().as_ptr();
    let batch_len = vectors.len();

    let start = self.committed_len.load(Ordering::Relaxed);
    let end = start.checked_add(batch_len).ok_or_else(|| {
        Error::invalid_input(format!(
            "vector count overflow: start={}, batch_len={}",
            start, batch_len
        ))
    })?;
    if end > self.capacity {
        return Err(Error::invalid_input(format!(
            "capacity {} exhausted: inserting rows [{}..{})",
            self.capacity, start, end
        )));
    }
    let batch_idx = self.committed_batches.load(Ordering::Relaxed);
    if batch_idx >= self.max_batches {
        return Err(Error::invalid_input(format!(
            "max_batches {} exhausted",
            self.max_batches
        )));
    }
    // `batch_len >= 1` here, so the largest id handed out is `row_id_start + batch_len - 1`.
    // Reject the batch up front instead of panicking (debug) or wrapping (release) mid-write.
    if row_id_start.checked_add((batch_len - 1) as u64).is_none() {
        return Err(Error::invalid_input(format!(
            "row id overflow: row_id_start={row_id_start}, batch_len={batch_len}"
        )));
    }

    // SAFETY: `batch_idx < max_batches` and `start + offset < end <= capacity`, so every
    // written slot lies inside the arrays allocated in `try_new`. The slots are at or above
    // the published counters, so no reader is expected to look at them yet.
    unsafe {
        self.batches
            .add(batch_idx)
            .write(MaybeUninit::new(StoredArrowBatch {
                _array: vectors,
                values_ptr,
            }));
        for offset in 0..batch_len {
            self.row_to_batch
                .add(start + offset)
                .write(MaybeUninit::new(RowLookup {
                    // Both casts are lossless: `try_new` bounds `max_batches` and
                    // `capacity` by `u32::MAX`.
                    batch_idx: batch_idx as u32,
                    offset: offset as u32,
                }));
            self.row_ids
                .add(start + offset)
                .write(MaybeUninit::new(row_id_start + offset as u64));
        }
    }

    // Publish the batch before the rows: a reader that observes the new row count through
    // an `Acquire` load then also observes the batch those rows point into.
    self.committed_batches
        .store(batch_idx + 1, Ordering::Release);
    self.committed_len.store(end, Ordering::Release);
    Ok(start as u32..end as u32)
}
/// Captures a read-only view of the rows published so far.
///
/// When exactly one batch has been committed, the snapshot also records the address of
/// that batch's values so lookups can skip the per-row indirection; otherwise the address
/// is `0` and lookups go through the row table.
pub fn snapshot(self: &Arc<Self>) -> VectorStoreSnapshot {
    // Load the row count BEFORE the batch count. `append_batch` publishes the batch count
    // first and the row count second, so after this `Acquire` load the batch count read
    // below is at least the number of batches that hold the visible rows. Loading them in
    // the opposite order lets a concurrent append slip in between, pairing "one batch"
    // with a row count that already includes rows of a second batch — and the contiguous
    // fast path would then read past the end of the first batch.
    let visible_len = self.committed_len();
    let committed_batches = self.committed_batches.load(Ordering::Acquire);
    let contiguous_values_addr = if committed_batches == 1 {
        // SAFETY: slot 0 was initialized before `committed_batches` was published as 1,
        // and the `Acquire` load above makes that write visible here.
        unsafe { (*self.batches.cast::<StoredArrowBatch>()).values_ptr as usize }
    } else {
        0
    };
    VectorStoreSnapshot {
        store: Arc::clone(self),
        visible_len,
        contiguous_values_addr,
    }
}
/// Exports every row published so far as a single `RecordBatch`.
///
/// `total_rows` is forwarded unchanged to `to_record_batch_with_len`.
pub fn to_record_batch(&self, total_rows: Option<u64>) -> Result<RecordBatch> {
    self.to_record_batch_with_len(self.committed_len(), total_rows)
}
/// Copies the first `visible_len` rows into a fresh `RecordBatch` using the store's schema.
///
/// With `total_rows = Some(n)`, each stored row id `r` is exported as `n - (r + 1)`; with
/// `None`, row ids are exported unchanged.
///
/// # Errors
/// Returns an invalid-input error when `n - (r + 1)` would be negative, and propagates any
/// Arrow error raised while assembling the arrays.
fn to_record_batch_with_len(
    &self,
    visible_len: usize,
    total_rows: Option<u64>,
) -> Result<RecordBatch> {
    let mut row_ids = Vec::with_capacity(visible_len);
    let mut values = Vec::with_capacity(visible_len * self.dim);
    // `visible_len` never exceeds `capacity`, which `try_new` bounds by `u32::MAX`.
    for id in 0..visible_len as u32 {
        let row_id = self.row_id_at(id);
        let exported = match total_rows {
            None => row_id,
            // Two checked subtractions instead of `row_id + 1`, which itself overflows
            // when `row_id == u64::MAX`.
            Some(total_rows) => total_rows
                .checked_sub(row_id)
                .and_then(|remaining| remaining.checked_sub(1))
                .ok_or_else(|| {
                    Error::invalid_input(format!(
                        "row id reversal underflow: total_rows={total_rows}, row_id={row_id}"
                    ))
                })?,
        };
        row_ids.push(exported);
        values.extend_from_slice(self.vector_at(id));
    }
    let row_ids = Arc::new(UInt64Array::from(row_ids)) as ArrayRef;
    let values = Arc::new(Float32Array::from(values)) as ArrayRef;
    // Must match the item field declared in the schema built by `try_new`.
    let field = Arc::new(Field::new("item", DataType::Float32, true));
    let vectors = Arc::new(FixedSizeListArray::try_new(
        field,
        self.dim as i32,
        values,
        None,
    )?) as ArrayRef;
    Ok(RecordBatch::try_new(
        self.schema.clone(),
        vec![row_ids, vectors],
    )?)
}
/// Reads the row id stored for the internal id `id`.
///
/// The bound is only checked in debug builds; callers must pass an id below the published
/// row count.
fn row_id_at(&self, id: u32) -> u64 {
    let slot = id as usize;
    debug_assert!(slot < self.committed_len.load(Ordering::Acquire));
    // SAFETY: slots below the published row count were initialized by `append_batch`
    // before the count was released, and `MaybeUninit<u64>` has the same layout as `u64`.
    unsafe { self.row_ids.add(slot).cast::<u64>().read() }
}
fn vector_at(&self, id: u32) -> &[f32] {
debug_assert!((id as usize) < self.committed_len.load(Ordering::Acquire));
unsafe {
let lookup = self.row_to_batch.add(id as usize).read().assume_init();
let batch = &*self
.batches
.add(lookup.batch_idx as usize)
.cast::<StoredArrowBatch>();
let ptr = batch.values_ptr.add(lookup.offset as usize * self.dim);
std::slice::from_raw_parts(ptr, self.dim)
}
}
}
/// Read-only view over the first `visible_len` rows of a shared store.
#[derive(Clone)]
pub struct VectorStoreSnapshot {
/// Shared handle keeping the underlying store (and its batches) alive.
store: Arc<ArrowFixedSizeListVectorStore>,
/// Number of rows this snapshot exposes, captured when the snapshot was taken.
visible_len: usize,
/// Address of the first batch's `f32` values when the snapshot was taken with exactly one
/// committed batch; `0` means "no contiguous fast path, use the row table".
contiguous_values_addr: usize,
}
impl VectorStoreSnapshot {
    /// Exports the rows visible to this snapshot as a single `RecordBatch`.
    ///
    /// `total_rows` is forwarded unchanged to the store's export routine.
    pub fn to_record_batch(&self, total_rows: Option<u64>) -> Result<RecordBatch> {
        let Self {
            store, visible_len, ..
        } = self;
        store.to_record_batch_with_len(*visible_len, total_rows)
    }
}
impl VectorSource for VectorStoreSnapshot {
    /// Number of rows captured by the snapshot.
    fn len(&self) -> usize {
        self.visible_len
    }

    /// Vector width of the underlying store.
    fn dim(&self) -> usize {
        self.store.dim
    }

    /// Metric configured on the underlying store.
    fn distance_type(&self) -> DistanceType {
        self.store.distance_type
    }

    /// Row id of the internal id `id`; the bound is only checked in debug builds.
    fn row_id(&self, id: u32) -> u64 {
        debug_assert!((id as usize) < self.visible_len);
        self.store.row_id_at(id)
    }

    /// Values of the internal id `id`; the bound is only checked in debug builds.
    fn vector(&self, id: u32) -> &[f32] {
        debug_assert!((id as usize) < self.visible_len);
        let base = self.contiguous_values_addr;
        if base == 0 {
            // No contiguous fast path recorded: resolve through the store's row table.
            return self.store.vector_at(id);
        }
        let width = self.store.dim;
        // SAFETY: relies on `snapshot` recording a non-zero address only when all
        // `visible_len` rows live in the first batch, whose buffer the shared store keeps
        // alive; `id < visible_len` then keeps the slice inside that buffer.
        unsafe {
            let first = (base as *const f32).add(id as usize * width);
            std::slice::from_raw_parts(first, width)
        }
    }
}
/// Allocates a boxed slice of `len` uninitialized slots.
fn uninit_boxed_slice<T>(len: usize) -> Box<[MaybeUninit<T>]> {
    let mut slots: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
    slots.resize_with(len, MaybeUninit::uninit);
    slots.into_boxed_slice()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-null `FixedSizeListArray` of `dim`-wide rows from flat `values`.
    fn fsl(values: Vec<f32>, dim: usize) -> Arc<FixedSizeListArray> {
        let item_field = Arc::new(Field::new("item", DataType::Float32, true));
        let child: ArrayRef = Arc::new(Float32Array::from(values));
        let list = FixedSizeListArray::try_new(item_field, dim as i32, child, None).unwrap();
        Arc::new(list)
    }

    /// Two appended batches are addressed through one dense id space without copying.
    #[test]
    fn test_arrow_store_reuses_batches() {
        let store = ArrowFixedSizeListVectorStore::try_new(8, 2, 2, DistanceType::L2)
            .map(Arc::new)
            .unwrap();

        // First batch holds two rows, second batch holds one.
        let first_range = store
            .append_batch(fsl(vec![1.0, 2.0, 3.0, 4.0], 2), 10)
            .unwrap();
        assert_eq!(first_range, 0..2);
        let second_range = store.append_batch(fsl(vec![5.0, 6.0], 2), 12).unwrap();
        assert_eq!(second_range, 2..3);

        let snapshot = store.snapshot();
        assert_eq!(snapshot.len(), 3);
        assert_eq!(snapshot.row_id(2), 12);
        assert_eq!(snapshot.vector(1), &[3.0, 4.0]);

        let l2 = compute_f32_distance(snapshot.vector(0), snapshot.vector(1), DistanceType::L2);
        assert_eq!(l2, 8.0);
    }
}