use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use arrow_array::cast::AsArray;
use arrow_array::types::UInt8Type;
use arrow_array::{Array, FixedSizeListArray, RecordBatch, UInt8Array};
use crossbeam_skiplist::SkipMap;
use lance_core::{Error, Result};
use lance_index::vector::ivf::storage::IvfModel;
use lance_index::vector::kmeans::compute_partitions_arrow_array;
use lance_index::vector::pq::ProductQuantizer;
use lance_index::vector::pq::storage::transpose;
use lance_index::vector::quantizer::Quantization;
use lance_linalg::distance::DistanceType;
use crate::dataset::mem_wal::memtable::batch_store::StoredBatch;
pub use super::RowPosition;
/// Error returned when a fixed-capacity partition store cannot accept
/// any more vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PartitionFull;

impl std::fmt::Display for PartitionFull {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("IVF-PQ partition store is full")
    }
}

impl std::error::Error for PartitionFull {}
/// Fixed-capacity, append-only store of PQ codes kept in column-major
/// (transposed) layout, plus each stored vector's row position.
///
/// Publication protocol (from the atomics used in the impl): writers fill
/// slots and then extend `committed_len` with a Release store; readers load it
/// with Acquire and only touch the prefix below it.
#[derive(Debug)]
struct ColumnMajorIvfPqMemPartition {
// Sub-vector byte `s` of vector `i` lives at index `s * capacity + i`.
// Slots at or beyond `committed_len` within each column are uninitialized.
codes: UnsafeCell<Box<[MaybeUninit<u8>]>>,
// Row position of vector `i` at index `i`; uninitialized past `committed_len`.
row_positions: UnsafeCell<Box<[MaybeUninit<u64>]>>,
// Number of fully initialized vectors; published with Release ordering.
committed_len: AtomicUsize,
// Maximum number of vectors this store can hold.
capacity: usize,
// PQ code bytes per vector (one column per byte).
num_sub_vectors: usize,
}
// SAFETY: sound only under a single-writer discipline — at most one thread
// appends at a time, and readers only dereference the prefix published via the
// `committed_len` Release/Acquire pair. Two concurrent appenders would race on
// the UnsafeCell contents (the append path does load-then-write-then-store).
// NOTE(review): confirm callers serialize appends per partition.
unsafe impl Sync for ColumnMajorIvfPqMemPartition {}
unsafe impl Send for ColumnMajorIvfPqMemPartition {}
impl ColumnMajorIvfPqMemPartition {
    /// Creates a store for up to `capacity` vectors of `num_sub_vectors` code
    /// bytes each. All slots start uninitialized; only the prefix below
    /// `committed_len` is ever read back.
    fn new(capacity: usize, num_sub_vectors: usize) -> Self {
        assert!(capacity > 0, "capacity must be > 0");
        assert!(num_sub_vectors > 0, "num_sub_vectors must be > 0");
        // `MaybeUninit<T>` is `Clone` for `T: Copy`, so the uninitialized slabs
        // can be allocated directly instead of pushing element by element.
        let codes = vec![MaybeUninit::uninit(); capacity * num_sub_vectors];
        let row_positions = vec![MaybeUninit::uninit(); capacity];
        Self {
            codes: UnsafeCell::new(codes.into_boxed_slice()),
            row_positions: UnsafeCell::new(row_positions.into_boxed_slice()),
            committed_len: AtomicUsize::new(0),
            capacity,
            num_sub_vectors,
        }
    }

    /// Number of committed (fully initialized and published) vectors.
    #[inline]
    fn len(&self) -> usize {
        self.committed_len.load(Ordering::Acquire)
    }

    /// Slots still available for future appends.
    #[inline]
    fn remaining_capacity(&self) -> usize {
        self.capacity
            .saturating_sub(self.committed_len.load(Ordering::Relaxed))
    }

    /// Appends a batch of already-transposed codes
    /// (`transposed_codes[s * n + i]` holds sub-vector byte `s` of vector `i`)
    /// together with each vector's row position.
    ///
    /// All-or-nothing: returns `Err(PartitionFull)` without writing anything
    /// when the batch does not fit. Assumes a single writer (see the `Sync`
    /// impl's safety note).
    fn append_transposed_batch(
        &self,
        transposed_codes: &[u8],
        positions: &[u64],
    ) -> std::result::Result<(), PartitionFull> {
        let num_vectors = positions.len();
        if num_vectors == 0 {
            return Ok(());
        }
        debug_assert_eq!(
            transposed_codes.len(),
            num_vectors * self.num_sub_vectors,
            "transposed_codes length mismatch: expected {}, got {}",
            num_vectors * self.num_sub_vectors,
            transposed_codes.len()
        );
        let committed = self.committed_len.load(Ordering::Relaxed);
        if committed + num_vectors > self.capacity {
            return Err(PartitionFull);
        }
        // SAFETY: single-writer contract; concurrent readers only touch
        // indices below the previously published `committed_len`, which we
        // extend only after all writes below complete (Release store).
        let codes = unsafe { &mut *self.codes.get() };
        let row_pos = unsafe { &mut *self.row_positions.get() };
        for subvec_idx in 0..self.num_sub_vectors {
            // Subslice once per column so the inner copy avoids per-element
            // bounds checks.
            let src = &transposed_codes[subvec_idx * num_vectors..][..num_vectors];
            let dst_start = subvec_idx * self.capacity + committed;
            for (dst, &byte) in codes[dst_start..][..num_vectors].iter_mut().zip(src) {
                dst.write(byte);
            }
        }
        for (dst, &pos) in row_pos[committed..][..num_vectors].iter_mut().zip(positions) {
            dst.write(pos);
        }
        // Publish the new length so readers may observe the new rows.
        self.committed_len
            .store(committed + num_vectors, Ordering::Release);
        Ok(())
    }

    /// Snapshots the committed prefix as `(column-major codes, row positions)`.
    fn get_codes_for_search(&self) -> (Vec<u8>, Vec<u64>) {
        let len = self.committed_len.load(Ordering::Acquire);
        if len == 0 {
            return (Vec::new(), Vec::new());
        }
        // SAFETY: every slot below `len` was initialized before the Release
        // store that published it; the Acquire load above synchronizes with
        // that store, so `assume_init` on the prefix is sound.
        let codes = unsafe { &*self.codes.get() };
        let row_pos = unsafe { &*self.row_positions.get() };
        let mut result_codes = Vec::with_capacity(len * self.num_sub_vectors);
        for subvec_idx in 0..self.num_sub_vectors {
            let start = subvec_idx * self.capacity;
            result_codes.extend(
                codes[start..start + len]
                    .iter()
                    .map(|c| unsafe { c.assume_init() }),
            );
        }
        let result_positions: Vec<u64> = row_pos[..len]
            .iter()
            .map(|p| unsafe { p.assume_init() })
            .collect();
        (result_codes, result_positions)
    }
}
/// One IVF partition: a fixed-capacity column-major primary store plus a
/// lock-free overflow map for vectors that arrive after the primary fills up.
#[derive(Debug)]
pub struct IvfPqMemPartition {
// Column-major (transposed) fixed-capacity store; the fast path for search.
primary: ColumnMajorIvfPqMemPartition,
// Row position -> row-major PQ code for vectors that did not fit in `primary`.
overflow: SkipMap<u64, Vec<u8>>,
// Cached number of overflow entries. Incremented after insertion into the
// map, so readers may briefly observe a count lagging the map contents.
overflow_count: AtomicUsize,
// PQ code bytes per vector.
num_sub_vectors: usize,
}
impl IvfPqMemPartition {
    /// Creates a partition whose primary store holds up to `capacity` vectors
    /// of `num_sub_vectors` code bytes each.
    pub fn new(capacity: usize, num_sub_vectors: usize) -> Self {
        Self {
            primary: ColumnMajorIvfPqMemPartition::new(capacity, num_sub_vectors),
            overflow: SkipMap::new(),
            overflow_count: AtomicUsize::new(0),
            num_sub_vectors,
        }
    }

    /// Inserts one row-major code per position into the overflow map and bumps
    /// the overflow counter.
    fn spill_to_overflow(&self, row_major_codes: &[u8], positions: &[u64]) {
        for (code, &pos) in row_major_codes
            .chunks_exact(self.num_sub_vectors)
            .zip(positions)
        {
            self.overflow.insert(pos, code.to_vec());
        }
        self.overflow_count
            .fetch_add(positions.len(), Ordering::Relaxed);
    }

    /// Appends a batch of row-major PQ codes (`num_sub_vectors` bytes per
    /// vector) with their row positions. Vectors go to the primary store while
    /// it has room; the remainder spills to the overflow map.
    pub fn append_batch(&self, row_major_codes: &[u8], positions: &[u64]) {
        let num_vectors = positions.len();
        if num_vectors == 0 {
            return;
        }
        debug_assert_eq!(
            row_major_codes.len(),
            num_vectors * self.num_sub_vectors,
            "row_major_codes length mismatch"
        );
        let primary_count = self.primary.remaining_capacity().min(num_vectors);
        if primary_count > 0 {
            let primary_codes = &row_major_codes[..primary_count * self.num_sub_vectors];
            let primary_positions = &positions[..primary_count];
            let codes_array = UInt8Array::from(primary_codes.to_vec());
            let transposed =
                transpose::<UInt8Type>(&codes_array, primary_count, self.num_sub_vectors);
            // The previous code discarded the Result here (`let _ = ...`): if
            // the primary store filled up between the capacity check above and
            // this append, the whole slice of vectors was silently lost. Spill
            // them to the overflow map instead.
            if self
                .primary
                .append_transposed_batch(transposed.values(), primary_positions)
                .is_err()
            {
                self.spill_to_overflow(primary_codes, primary_positions);
            }
        }
        if primary_count < num_vectors {
            self.spill_to_overflow(
                &row_major_codes[primary_count * self.num_sub_vectors..],
                &positions[primary_count..],
            );
        }
    }

    /// True when at least one vector has spilled into the overflow map.
    #[inline]
    pub fn has_overflow(&self) -> bool {
        self.overflow_count.load(Ordering::Relaxed) > 0
    }

    /// Total vectors stored (primary + overflow).
    #[inline]
    pub fn len(&self) -> usize {
        self.primary.len() + self.overflow_count.load(Ordering::Relaxed)
    }

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Snapshot of the primary store as `(column-major codes, row positions)`.
    pub fn get_primary_codes_for_search(&self) -> (Vec<u8>, Vec<u64>) {
        self.primary.get_codes_for_search()
    }

    /// Snapshot of the overflow map as `(row-major codes, row positions)`.
    ///
    /// `overflow_count` is only used as a capacity hint; the map itself is
    /// iterated, so entries inserted concurrently may or may not be included.
    pub fn get_overflow_codes_for_search(&self) -> (Vec<u8>, Vec<u64>) {
        let overflow_count = self.overflow_count.load(Ordering::Acquire);
        if overflow_count == 0 {
            return (Vec::new(), Vec::new());
        }
        let mut codes = Vec::with_capacity(overflow_count * self.num_sub_vectors);
        let mut positions = Vec::with_capacity(overflow_count);
        for entry in self.overflow.iter() {
            positions.push(*entry.key());
            codes.extend_from_slice(entry.value());
        }
        (codes, positions)
    }
}
/// One indexed vector: its row position and its row-major PQ code bytes.
#[derive(Debug, Clone)]
pub struct IvfPqEntry {
/// Position of the row within the memtable.
pub row_position: RowPosition,
/// Row-major PQ code (`code_len` bytes).
pub pq_code: Vec<u8>,
}
/// In-memory IVF-PQ vector index over a single vector column of a memtable.
#[derive(Debug)]
pub struct IvfPqMemIndex {
// Schema field id of the indexed column.
field_id: i32,
// Name of the indexed vector column.
column_name: String,
// IVF centroids used to route vectors to partitions.
ivf_model: IvfModel,
// Product quantizer used to encode vectors.
pq: ProductQuantizer,
// One store per IVF partition; indexed by partition id.
partitions: Vec<IvfPqMemPartition>,
// Total vectors successfully inserted across all partitions.
vector_count: AtomicUsize,
distance_type: DistanceType,
// Cached `ivf_model.num_partitions()`.
num_partitions: usize,
// Bytes of PQ code per vector.
code_len: usize,
}
/// Default per-partition primary-store capacity (in vectors) used by `new`.
const DEFAULT_PARTITION_CAPACITY: usize = 1024;
impl IvfPqMemIndex {
/// Creates an index using the default per-partition primary capacity
/// (`DEFAULT_PARTITION_CAPACITY`); see [`Self::with_capacity`].
pub fn new(
field_id: i32,
column_name: String,
ivf_model: IvfModel,
pq: ProductQuantizer,
distance_type: DistanceType,
) -> Self {
Self::with_capacity(
field_id,
column_name,
ivf_model,
pq,
distance_type,
DEFAULT_PARTITION_CAPACITY,
)
}
/// Creates an index with an explicit per-partition primary-store capacity
/// (in vectors); overflow beyond that capacity goes to each partition's
/// skip-list map.
pub fn with_capacity(
field_id: i32,
column_name: String,
ivf_model: IvfModel,
pq: ProductQuantizer,
distance_type: DistanceType,
partition_capacity: usize,
) -> Self {
let num_partitions = ivf_model.num_partitions();
// Bytes of PQ code per vector. NOTE(review): this integer division
// truncates when `num_sub_vectors * num_bits` is not a multiple of 8
// (e.g. 4-bit PQ with an odd sub-vector count) — confirm upstream
// guarantees it divides evenly.
let code_len = pq.num_sub_vectors * pq.num_bits as usize / 8;
let partitions: Vec<_> = (0..num_partitions)
.map(|_| IvfPqMemPartition::new(partition_capacity, code_len))
.collect();
Self {
field_id,
column_name,
ivf_model,
pq,
partitions,
vector_count: AtomicUsize::new(0),
distance_type,
num_partitions,
code_len,
}
}
/// Schema field id of the indexed column.
pub fn field_id(&self) -> i32 {
self.field_id
}
/// Routes each vector in `batch`'s indexed column to its nearest IVF
/// partition, PQ-encodes it, and appends code + row position
/// (`row_offset + row_index`) to the per-partition stores.
///
/// Returns `Ok(())` without doing anything when the batch lacks the indexed
/// column. Rows whose computed partition id is null or out of range are
/// silently skipped and not counted.
///
/// # Errors
/// Fails if the column is not a `FixedSizeList`, the IVF model has no
/// centroids, or partition assignment / quantization fails.
pub fn insert(&self, batch: &RecordBatch, row_offset: u64) -> Result<()> {
let col_idx = batch
.schema()
.column_with_name(&self.column_name)
.map(|(idx, _)| idx);
// A missing column is not an error: this batch has nothing to index.
let Some(col_idx) = col_idx else {
return Ok(());
};
let column = batch.column(col_idx);
let fsl = column.as_fixed_size_list_opt().ok_or_else(|| {
Error::invalid_input(format!(
"Column '{}' is not a FixedSizeList, got {:?}",
self.column_name,
column.data_type()
))
})?;
let centroids = self
.ivf_model
.centroids
.as_ref()
.ok_or_else(|| Error::invalid_input("IVF model has no centroids"))?;
let (partition_ids, _distances) =
compute_partitions_arrow_array(centroids, fsl, self.distance_type)?;
// Quantize all vectors at once; codes come back as a FixedSizeList whose
// flat values hold `code_len` bytes per vector, in row order.
let pq_codes = self.pq.quantize(fsl)?;
let pq_codes_fsl = pq_codes.as_fixed_size_list();
let pq_codes_flat = pq_codes_fsl
.values()
.as_primitive::<arrow_array::types::UInt8Type>();
// Group row indices by target partition so each partition gets one append.
let mut partition_groups: Vec<Vec<usize>> = vec![Vec::new(); self.num_partitions];
for (row_idx, partition_id) in partition_ids.iter().enumerate().take(batch.num_rows()) {
if let Some(pid) = partition_id
&& (*pid as usize) < self.num_partitions
{
partition_groups[*pid as usize].push(row_idx);
}
}
let mut total_inserted = 0usize;
for (partition_id, indices) in partition_groups.iter().enumerate() {
if indices.is_empty() {
continue;
}
let num_vectors = indices.len();
let mut partition_codes: Vec<u8> = Vec::with_capacity(num_vectors * self.code_len);
let mut partition_positions: Vec<u64> = Vec::with_capacity(num_vectors);
// Gather this partition's codes (row-major) and absolute row positions.
for &row_idx in indices {
let code_start = row_idx * self.code_len;
let code_end = code_start + self.code_len;
partition_codes.extend_from_slice(&pq_codes_flat.values()[code_start..code_end]);
partition_positions.push(row_offset + row_idx as u64);
}
self.partitions[partition_id].append_batch(&partition_codes, &partition_positions);
total_inserted += num_vectors;
}
self.vector_count
.fetch_add(total_inserted, Ordering::Relaxed);
Ok(())
}
pub fn insert_batches(&self, batches: &[StoredBatch]) -> Result<()> {
if batches.is_empty() {
return Ok(());
}
let mut vector_arrays: Vec<&FixedSizeListArray> = Vec::with_capacity(batches.len());
let mut batch_infos: Vec<(u64, usize, usize)> = Vec::with_capacity(batches.len());
for stored in batches {
let col_idx = stored
.data
.schema()
.column_with_name(&self.column_name)
.map(|(idx, _)| idx);
if let Some(col_idx) = col_idx {
let column = stored.data.column(col_idx);
if let Some(fsl) = column.as_fixed_size_list_opt() {
let num_vectors = fsl.len();
if num_vectors > 0 {
vector_arrays.push(fsl);
batch_infos.push((stored.row_offset, num_vectors, stored.batch_position));
}
}
}
}
if vector_arrays.is_empty() {
return Ok(());
}
let arrays_as_refs: Vec<&dyn Array> =
vector_arrays.iter().map(|a| *a as &dyn Array).collect();
let concatenated = arrow_select::concat::concat(&arrays_as_refs)?;
let mega_fsl = concatenated.as_fixed_size_list();
let total_vectors = mega_fsl.len();
let centroids = self
.ivf_model
.centroids
.as_ref()
.ok_or_else(|| Error::invalid_input("IVF model has no centroids"))?;
let (partition_ids, _distances) =
compute_partitions_arrow_array(centroids, mega_fsl, self.distance_type)?;
let pq_codes = self.pq.quantize(mega_fsl)?;
let pq_codes_fsl = pq_codes.as_fixed_size_list();
let pq_codes_flat = pq_codes_fsl
.values()
.as_primitive::<arrow_array::types::UInt8Type>();
let mut row_positions: Vec<u64> = Vec::with_capacity(total_vectors);
for (row_offset, num_vectors, _) in &batch_infos {
for i in 0..*num_vectors {
row_positions.push(row_offset + i as u64);
}
}
let mut partition_groups: Vec<Vec<usize>> = vec![Vec::new(); self.num_partitions];
for (idx, pid) in partition_ids.iter().enumerate() {
if let Some(pid) = pid
&& (*pid as usize) < self.num_partitions
{
partition_groups[*pid as usize].push(idx);
}
}
let mut total_inserted = 0usize;
for (partition_id, indices) in partition_groups.iter().enumerate() {
if indices.is_empty() {
continue;
}
let num_vectors = indices.len();
let mut partition_codes: Vec<u8> = Vec::with_capacity(num_vectors * self.code_len);
let mut partition_positions: Vec<u64> = Vec::with_capacity(num_vectors);
for &idx in indices {
let code_start = idx * self.code_len;
let code_end = code_start + self.code_len;
partition_codes.extend_from_slice(&pq_codes_flat.values()[code_start..code_end]);
partition_positions.push(row_positions[idx]);
}
self.partitions[partition_id].append_batch(&partition_codes, &partition_positions);
total_inserted += num_vectors;
}
self.vector_count
.fetch_add(total_inserted, Ordering::Relaxed);
Ok(())
}
pub fn search(
&self,
query: &FixedSizeListArray,
k: usize,
nprobes: usize,
max_row_position: RowPosition,
) -> Result<Vec<(f32, RowPosition)>> {
if query.len() != 1 {
return Err(Error::invalid_input(format!(
"Query must have exactly 1 vector, got {}",
query.len()
)));
}
let query_values = query.value(0);
let (partition_ids, _) =
self.ivf_model
.find_partitions(&query_values, nprobes, self.distance_type)?;
let mut results: Vec<(f32, RowPosition)> = Vec::new();
for i in 0..partition_ids.len() {
let partition_id = partition_ids.value(i) as usize;
if partition_id >= self.num_partitions {
continue;
}
let partition = &self.partitions[partition_id];
if partition.is_empty() {
continue;
}
let (primary_codes, primary_positions) = partition.get_primary_codes_for_search();
if !primary_codes.is_empty() {
let codes_array = UInt8Array::from(primary_codes);
let distances = self.pq.compute_distances(&query_values, &codes_array)?;
for (idx, &dist) in distances.values().iter().enumerate() {
let pos = primary_positions[idx];
if pos <= max_row_position {
results.push((dist, pos));
}
}
}
if partition.has_overflow() {
let (overflow_codes_rowmajor, overflow_positions) =
partition.get_overflow_codes_for_search();
if !overflow_codes_rowmajor.is_empty() {
let num_overflow = overflow_positions.len();
let codes_array = UInt8Array::from(overflow_codes_rowmajor);
let transposed = transpose::<arrow_array::types::UInt8Type>(
&codes_array,
num_overflow,
self.code_len,
);
let distances = self.pq.compute_distances(&query_values, &transposed)?;
for (idx, &dist) in distances.values().iter().enumerate() {
let pos = overflow_positions[idx];
if pos <= max_row_position {
results.push((dist, pos));
}
}
}
}
}
results.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(k);
Ok(results)
}
/// Total number of vectors inserted across all partitions.
pub fn len(&self) -> usize {
self.vector_count.load(Ordering::Relaxed)
}
/// True when no vectors have been inserted yet.
pub fn is_empty(&self) -> bool {
self.vector_count.load(Ordering::Relaxed) == 0
}
/// Name of the indexed vector column.
pub fn column_name(&self) -> &str {
&self.column_name
}
pub fn get_partition(&self, partition_id: usize) -> Vec<IvfPqEntry> {
if partition_id >= self.num_partitions {
return Vec::new();
}
let partition = &self.partitions[partition_id];
let mut entries = Vec::with_capacity(partition.len());
let (primary_codes, primary_positions) = partition.get_primary_codes_for_search();
if !primary_codes.is_empty() {
let num_vectors = primary_positions.len();
for (i, &row_position) in primary_positions.iter().enumerate() {
let mut pq_code = Vec::with_capacity(self.code_len);
for sv in 0..self.code_len {
pq_code.push(primary_codes[sv * num_vectors + i]);
}
entries.push(IvfPqEntry {
row_position,
pq_code,
});
}
}
let (overflow_codes, overflow_positions) = partition.get_overflow_codes_for_search();
for (i, &row_position) in overflow_positions.iter().enumerate() {
let code_start = i * self.code_len;
let code_end = code_start + self.code_len;
entries.push(IvfPqEntry {
row_position,
pq_code: overflow_codes[code_start..code_end].to_vec(),
});
}
entries
}
/// Number of IVF partitions (from the IVF model).
pub fn num_partitions(&self) -> usize {
self.ivf_model.num_partitions()
}
/// The IVF model (centroids) backing this index.
pub fn ivf_model(&self) -> &IvfModel {
&self.ivf_model
}
/// The product quantizer backing this index.
pub fn pq(&self) -> &ProductQuantizer {
&self.pq
}
/// Distance metric used for partitioning and search.
pub fn distance_type(&self) -> DistanceType {
self.distance_type
}
pub fn to_partition_batches(&self) -> Result<Vec<(usize, RecordBatch)>> {
use arrow_array::UInt64Array;
use arrow_schema::{Field, Schema};
use lance_core::ROW_ID;
use lance_index::vector::PQ_CODE_COLUMN;
use std::sync::Arc;
let pq_code_len = self.pq.num_sub_vectors * self.pq.num_bits as usize / 8;
let schema = Arc::new(Schema::new(vec![
Field::new(ROW_ID, arrow_schema::DataType::UInt64, false),
Field::new(
PQ_CODE_COLUMN,
arrow_schema::DataType::FixedSizeList(
Arc::new(Field::new("item", arrow_schema::DataType::UInt8, false)),
pq_code_len as i32,
),
false,
),
]));
let mut result = Vec::new();
for part_id in 0..self.num_partitions {
let entries = self.get_partition(part_id);
if entries.is_empty() {
continue;
}
let row_ids: Vec<u64> = entries.iter().map(|e| e.row_position).collect();
let row_id_array = Arc::new(UInt64Array::from(row_ids));
let mut pq_codes_flat: Vec<u8> = Vec::with_capacity(entries.len() * pq_code_len);
for entry in &entries {
pq_codes_flat.extend_from_slice(&entry.pq_code);
}
let pq_codes_array = UInt8Array::from(pq_codes_flat);
let inner_field = Arc::new(Field::new("item", arrow_schema::DataType::UInt8, false));
let pq_codes_fsl = Arc::new(
FixedSizeListArray::try_new(
inner_field,
pq_code_len as i32,
Arc::new(pq_codes_array),
None,
)
.map_err(|e| Error::io(format!("Failed to create PQ code array: {}", e)))?,
);
let batch = RecordBatch::try_new(schema.clone(), vec![row_id_array, pq_codes_fsl])
.map_err(|e| Error::io(format!("Failed to create partition batch: {}", e)))?;
result.push((part_id, batch));
}
Ok(result)
}
pub fn to_partition_batches_reversed(
&self,
total_rows: usize,
) -> Result<Vec<(usize, RecordBatch)>> {
use arrow_array::UInt64Array;
use arrow_schema::{Field, Schema};
use lance_core::ROW_ID;
use lance_index::vector::PQ_CODE_COLUMN;
use std::sync::Arc;
let pq_code_len = self.pq.num_sub_vectors * self.pq.num_bits as usize / 8;
let total_rows_u64 = total_rows as u64;
let schema = Arc::new(Schema::new(vec![
Field::new(ROW_ID, arrow_schema::DataType::UInt64, false),
Field::new(
PQ_CODE_COLUMN,
arrow_schema::DataType::FixedSizeList(
Arc::new(Field::new("item", arrow_schema::DataType::UInt8, false)),
pq_code_len as i32,
),
false,
),
]));
let mut result = Vec::new();
for part_id in 0..self.num_partitions {
let entries = self.get_partition(part_id);
if entries.is_empty() {
continue;
}
let row_ids: Vec<u64> = entries
.iter()
.map(|e| total_rows_u64 - e.row_position - 1)
.collect();
let row_id_array = Arc::new(UInt64Array::from(row_ids));
let mut pq_codes_flat: Vec<u8> = Vec::with_capacity(entries.len() * pq_code_len);
for entry in &entries {
pq_codes_flat.extend_from_slice(&entry.pq_code);
}
let pq_codes_array = UInt8Array::from(pq_codes_flat);
let inner_field = Arc::new(Field::new("item", arrow_schema::DataType::UInt8, false));
let pq_codes_fsl = Arc::new(
FixedSizeListArray::try_new(
inner_field,
pq_code_len as i32,
Arc::new(pq_codes_array),
None,
)
.map_err(|e| Error::io(format!("Failed to create PQ code array: {}", e)))?,
);
let batch = RecordBatch::try_new(schema.clone(), vec![row_id_array, pq_codes_fsl])
.map_err(|e| Error::io(format!("Failed to create partition batch: {}", e)))?;
result.push((part_id, batch));
}
Ok(result)
}
}
/// Configuration needed to build an [`IvfPqMemIndex`] for one vector column.
#[derive(Debug, Clone)]
pub struct IvfPqIndexConfig {
/// Index name.
pub name: String,
/// Schema field id of the indexed column.
pub field_id: i32,
/// Name of the indexed vector column.
pub column: String,
/// Pre-trained IVF model (centroids).
pub ivf_model: IvfModel,
/// Pre-trained product quantizer.
pub pq: ProductQuantizer,
/// Distance metric for partitioning and search.
pub distance_type: DistanceType,
}
#[cfg(test)]
mod tests {
use super::*;
// Already-transposed codes appended to the column-major store come back
// unchanged from `get_codes_for_search`, with positions in insert order.
#[test]
fn test_partition_store_append_transposed() {
let store = ColumnMajorIvfPqMemPartition::new(100, 4);
let transposed_codes = vec![
10, 20, 30, 11, 21, 31, 12, 22, 32, 13, 23, 33,
];
let positions = vec![100, 200, 300];
store
.append_transposed_batch(&transposed_codes, &positions)
.unwrap();
assert_eq!(store.len(), 3);
assert_eq!(store.remaining_capacity(), 97);
let (codes, pos) = store.get_codes_for_search();
assert_eq!(pos, vec![100, 200, 300]);
assert_eq!(codes, transposed_codes);
}
// Appending beyond capacity is rejected with `PartitionFull` (all-or-nothing).
#[test]
fn test_partition_store_full() {
let store = ColumnMajorIvfPqMemPartition::new(2, 4);
let codes1 = vec![1, 2, 3, 4, 5, 6, 7, 8]; let pos1 = vec![10, 20];
store.append_transposed_batch(&codes1, &pos1).unwrap();
assert_eq!(store.remaining_capacity(), 0);
let codes2 = vec![9, 10, 11, 12];
let pos2 = vec![30];
assert!(store.append_transposed_batch(&codes2, &pos2).is_err());
}
// A batch that fits entirely in the primary store is transposed on the way in
// and never touches the overflow map.
#[test]
fn test_ivfpq_partition_primary_only() {
let partition = IvfPqMemPartition::new(100, 4);
let row_major = vec![
10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33, ];
let positions = vec![100, 200, 300];
partition.append_batch(&row_major, &positions);
assert_eq!(partition.len(), 3);
assert!(!partition.has_overflow());
let (codes, pos) = partition.get_primary_codes_for_search();
assert_eq!(pos, vec![100, 200, 300]);
// Expected codes are the column-major transpose of `row_major`.
assert_eq!(
codes,
vec![
10, 20, 30, 11, 21, 31, 12, 22, 32, 13, 23, 33, ]
);
}
// A batch larger than the primary capacity is split: the first `capacity`
// vectors land in the primary store, the rest spill to the overflow map.
#[test]
fn test_ivfpq_partition_overflow() {
let partition = IvfPqMemPartition::new(2, 4);
let row_major = vec![
10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33, 40, 41, 42, 43, ];
let positions = vec![100, 200, 300, 400];
partition.append_batch(&row_major, &positions);
assert_eq!(partition.len(), 4);
assert!(partition.has_overflow());
let (primary_codes, primary_pos) = partition.get_primary_codes_for_search();
assert_eq!(primary_pos, vec![100, 200]);
assert_eq!(
primary_codes,
vec![
10, 20, 11, 21, 12, 22, 13, 23, ]
);
let (overflow_codes, overflow_pos) = partition.get_overflow_codes_for_search();
assert_eq!(overflow_pos.len(), 2);
assert!(overflow_pos.contains(&300));
assert!(overflow_pos.contains(&400));
assert_eq!(overflow_codes.len(), 8);
}
// Once the primary store is full, subsequent batches go entirely to overflow.
#[test]
fn test_ivfpq_partition_all_overflow() {
let partition = IvfPqMemPartition::new(2, 4);
let batch1 = vec![1, 2, 3, 4, 5, 6, 7, 8];
partition.append_batch(&batch1, &[10, 20]);
assert!(!partition.has_overflow());
let batch2 = vec![11, 12, 13, 14, 21, 22, 23, 24, 31, 32, 33, 34];
partition.append_batch(&batch2, &[30, 40, 50]);
assert_eq!(partition.len(), 5);
assert!(partition.has_overflow());
}
}