use std::sync::Arc;
use std::{collections::BTreeMap, ops::Range};
use arrow_array::types::UInt32Type;
use arrow_array::{cast::AsArray, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use futures::{stream::repeat_with, StreamExt};
use lance_arrow::RecordBatchExt;
use lance_core::{io::Writer, ROW_ID, ROW_ID_FIELD};
use lance_index::vector::pq::ProductQuantizer;
use lance_index::vector::{PART_ID_COLUMN, PQ_CODE_COLUMN};
use lance_linalg::distance::MetricType;
use snafu::{location, Location};
use tracing::instrument;
use crate::index::vector::ivf::{
io::write_index_partitions,
shuffler::{Shuffler, ShufflerBuilder},
Ivf,
};
use crate::{io::RecordBatchStream, Error, Result};
pub async fn shuffle_dataset(
data: impl RecordBatchStream + Unpin,
column: &str,
ivf: Arc<dyn lance_index::vector::ivf::Ivf>,
num_sub_vectors: usize,
) -> Result<Shuffler> {
let mut stream = data
.zip(repeat_with(|| ivf.clone()))
.map(|(b, ivf)| async move {
let batch = b?;
ivf.partition_transform(&batch, column).await
})
.buffer_unordered(num_cpus::get() * 2)
.map(|batch| async move {
let batch = batch?;
tokio::task::spawn_blocking(move || {
let part_id = batch
.column_by_name(PART_ID_COLUMN)
.expect("The caller already checked column exist");
let part_id_arr = part_id.as_primitive::<UInt32Type>();
let mut cnt_map = BTreeMap::<u32, Vec<u32>>::new();
for (idx, part_id) in part_id_arr.values().iter().enumerate() {
cnt_map.entry(*part_id).or_default().push(idx as u32);
}
cnt_map
.into_iter()
.map(|(part_id, row_ids)| {
let indices = UInt32Array::from(row_ids);
let batch = batch.take(&indices)?;
Ok((part_id, batch))
})
.collect::<Result<Vec<_>>>()
})
.await
.map_err(|e| Error::Index {
message: e.to_string(),
location: location!(),
})
})
.buffer_unordered(num_cpus::get())
.boxed();
let schema = Schema::new(vec![
ROW_ID_FIELD.clone(),
Field::new(PART_ID_COLUMN, DataType::UInt32, false),
Field::new(
PQ_CODE_COLUMN,
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::UInt8, true)),
num_sub_vectors as i32,
),
false,
),
]);
const FLUSH_THRESHOLD: usize = 40 * 1024;
let mut shuffler_builder = ShufflerBuilder::try_new(&schema, FLUSH_THRESHOLD).await?;
while let Some(result) = stream.next().await {
let batches = result??;
if batches.is_empty() {
continue;
}
for (part_id, batch) in batches {
shuffler_builder.insert(part_id, batch).await?;
}
}
shuffler_builder.finish().await
}
#[instrument(level = "debug", skip(writer, data, ivf, pq))]
pub(super) async fn build_partitions(
writer: &mut dyn Writer,
data: impl RecordBatchStream + Unpin,
column: &str,
ivf: &mut Ivf,
pq: Arc<dyn ProductQuantizer>,
metric_type: MetricType,
part_range: Range<u32>,
) -> Result<()> {
let schema = data.schema();
if schema.column_with_name(column).is_none() {
return Err(Error::Schema {
message: format!("column {} does not exist in data stream", column),
location: location!(),
});
}
if schema.column_with_name(ROW_ID).is_none() {
return Err(Error::Schema {
message: "ROW ID is not set when building index partitions".to_string(),
location: location!(),
});
}
let ivf_model = lance_index::vector::ivf::new_ivf_with_pq(
ivf.centroids.values(),
ivf.centroids.value_length() as usize,
metric_type,
column,
pq.clone(),
Some(part_range),
)?;
let shuffler = shuffle_dataset(data, column, ivf_model, pq.num_sub_vectors()).await?;
write_index_partitions(writer, ivf, &shuffler, None).await?;
Ok(())
}