use std::{
any::Any,
cmp::Reverse,
collections::{BTreeMap, BinaryHeap, HashSet},
sync::Arc,
};
use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use async_trait::async_trait;
use object_store::path::Path;
use ordered_float::OrderedFloat;
use super::row_vertex::{RowVertex, RowVertexSerDe};
use crate::{
dataset::{Dataset, ROW_ID},
index::{
vector::{
graph::{GraphReadParams, PersistedGraph},
SCORE_COL,
},
Index,
},
io::deletion::LruDeletionVectorStore,
Result,
};
use crate::{
index::{
vector::VectorIndex,
vector::{
graph::{Graph, VertexWithDistance},
Query,
},
},
io::object_reader::ObjectReader,
Error,
};
/// Mutable bookkeeping carried through a single graph search.
pub struct SearchState {
/// Ids of vertices that were handed out by `pop` or marked through `visit`.
pub visited: HashSet<usize>,
/// Candidate list mapping distance to vertex id; `push` trims it to at most `l` entries
/// by dropping the entry with the largest distance.
///
/// NOTE: the key is the distance alone, so inserting a vertex whose distance equals an
/// existing key replaces the vertex id stored under that key.
candidates: BTreeMap<OrderedFloat<f32>, usize>,
/// Queue of every pushed vertex, wrapped in `Reverse` so the smallest element pops first
/// (presumably the closest vertex; depends on `VertexWithDistance`'s ordering).
heap: BinaryHeap<Reverse<VertexWithDistance>>,
/// Ids of every vertex that has been passed to `push` (field name is a misspelling of
/// "heap_visited").
heap_visisted: HashSet<usize>,
/// Maximum number of entries kept in `candidates`.
l: usize,
// `k` is stored by `new` but never read in this file, hence the lint suppression.
#[allow(dead_code)]
k: usize,
}
impl SearchState {
    /// Builds an empty search state.
    ///
    /// * `k` - stored on the state; nothing in this impl reads it back.
    /// * `l` - upper bound on the number of entries kept in the candidate list.
    pub(crate) fn new(k: usize, l: usize) -> Self {
        let visited = HashSet::new();
        let heap_visisted = HashSet::new();
        Self {
            k,
            l,
            visited,
            heap_visisted,
            heap: BinaryHeap::new(),
            candidates: BTreeMap::new(),
        }
    }

    /// Removes entries from the heap until one is found whose distance is still a key of
    /// the candidate list, marks that vertex as visited and returns its id.
    ///
    /// Returns `None` once the heap has been drained.
    fn pop(&mut self) -> Option<usize> {
        loop {
            // `?` leaves the function with `None` as soon as the heap is empty.
            let Reverse(nearest) = self.heap.pop()?;
            if self.candidates.contains_key(&nearest.distance) {
                self.visited.insert(nearest.id);
                break Some(nearest.id);
            }
            // The distance is no longer in the candidate list (it was trimmed by `push`),
            // so this heap entry is discarded and the next one is tried.
        }
    }

    /// Registers a vertex together with its distance.
    ///
    /// The vertex is recorded as pushed, queued on the heap and inserted into the
    /// candidate list; when the list grows beyond `l` entries its largest-distance entry
    /// is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `vertex_id` is already in `visited`.
    fn push(&mut self, vertex_id: usize, distance: f32) {
        assert!(!self.visited.contains(&vertex_id));

        self.heap_visisted.insert(vertex_id);

        let queued = VertexWithDistance::new(vertex_id, distance);
        self.heap.push(Reverse(queued));

        self.candidates.insert(OrderedFloat(distance), vertex_id);
        let over_capacity = self.candidates.len() > self.l;
        if over_capacity {
            self.candidates.pop_last();
        }
    }

    /// Marks `vertex_id` as visited.
    fn visit(&mut self, vertex_id: usize) {
        self.visited.insert(vertex_id);
    }

    /// Reports whether `vertex_id` has been visited or has ever been pushed.
    fn is_visited(&self, vertex_id: usize) -> bool {
        [&self.visited, &self.heap_visisted]
            .iter()
            .any(|seen| seen.contains(&vertex_id))
    }
}
/// Greedy search over `graph`, beginning at vertex `start`.
///
/// The start vertex is pushed with its distance to `query`; afterwards vertices are popped
/// from the state one at a time, and every neighbor that has been neither visited nor
/// pushed before is scored against `query` and pushed.
///
/// * `graph` - graph providing neighbor lists and distance computation.
/// * `start` - id of the vertex the traversal starts from.
/// * `query` - query vector.
/// * `k` - forwarded to [`SearchState::new`].
/// * `search_size` - capacity of the candidate list kept by the state.
///
/// # Errors
///
/// Propagates any error returned by `graph.distance_to` or `graph.neighbors`.
pub async fn greedy_search(
    graph: &(dyn Graph + Send + Sync),
    start: usize,
    query: &[f32],
    k: usize,
    search_size: usize,
) -> Result<SearchState> {
    let mut state = SearchState::new(k, search_size);

    let start_distance = graph.distance_to(query, start).await?;
    state.push(start, start_distance);

    while let Some(current) = state.pop() {
        state.visit(current);

        let neighbors = graph.neighbors(current).await?;
        for raw_id in neighbors.values() {
            let neighbor = *raw_id as usize;
            // Only score vertices that were never visited or pushed before.
            if !state.is_visited(neighbor) {
                let distance = graph.distance_to(query, neighbor).await?;
                state.push(neighbor, distance);
            }
        }
    }

    Ok(state)
}
/// Vector index that answers queries by searching a persisted graph of [`RowVertex`] entries.
pub struct DiskANNIndex {
/// Graph opened from the `graph_path` given to `try_new`.
graph: PersistedGraph<RowVertex>,
/// Shared deletion-vector store; `search` queries it to skip deleted rows.
deletion_cache: Arc<LruDeletionVectorStore>,
}
impl std::fmt::Debug for DiskANNIndex {
    /// Writes the fixed type name; none of the index's fields are printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("DiskANNIndex")
    }
}
impl DiskANNIndex {
    /// Opens the persisted graph at `graph_path` and wraps it in an index.
    ///
    /// * `dataset` - dataset handed to the graph reader.
    /// * `index_column` - name of the indexed column, handed to the graph reader.
    /// * `graph_path` - location of the persisted graph.
    /// * `deletion_cache` - deletion-vector store kept on the index.
    ///
    /// The graph is read with default [`GraphReadParams`] and a [`RowVertexSerDe`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`PersistedGraph::try_new`].
    pub async fn try_new(
        dataset: Arc<Dataset>,
        index_column: &str,
        graph_path: &Path,
        deletion_cache: Arc<LruDeletionVectorStore>,
    ) -> Result<Self> {
        let read_params = GraphReadParams::default();
        let vertex_serde = Arc::new(RowVertexSerDe::new());

        PersistedGraph::try_new(dataset, index_column, graph_path, read_params, vertex_serde)
            .await
            .map(|graph| Self {
                graph,
                deletion_cache,
            })
    }
}
impl Index for DiskANNIndex {
    /// Exposes the index as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }
}
#[async_trait]
impl VectorIndex for DiskANNIndex {
async fn search(&self, query: &Query) -> Result<RecordBatch> {
let state = greedy_search(&self.graph, 0, query.key.values(), query.k, query.k * 2).await?;
let schema = Arc::new(Schema::new(vec![
Field::new(ROW_ID, DataType::UInt64, false),
Field::new(SCORE_COL, DataType::Float32, false),
]));
let mut candidates = Vec::with_capacity(query.k);
for (score, row) in state.candidates {
if candidates.len() == query.k {
break;
}
if !self.deletion_cache.as_ref().is_deleted(row as u64).await? {
candidates.push((score, row));
}
}
let row_ids: UInt64Array = candidates
.iter()
.take(query.k)
.map(|(_, id)| *id as u64)
.collect();
let scores: Float32Array = candidates.iter().take(query.k).map(|(d, _)| **d).collect();
let batch = RecordBatch::try_new(
schema,
vec![Arc::new(row_ids) as ArrayRef, Arc::new(scores) as ArrayRef],
)?;
Ok(batch)
}
fn is_loadable(&self) -> bool {
false
}
async fn load(
&self,
_reader: &dyn ObjectReader,
_offset: usize,
_length: usize,
) -> Result<Arc<dyn VectorIndex>> {
Err(Error::Index {
message: "DiskANNIndex is not loadable".to_string(),
})
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Pushes 40 vertices in descending distance order into a state whose candidate list
    /// holds 20 entries, then checks that popping yields exactly the 20 closest ids in
    /// ascending order.
    #[test]
    fn test_search_state() {
        let k: usize = 10;
        let l: usize = 20;
        let mut state = SearchState::new(k, l);

        for vertex in (0..40).rev() {
            state.push(vertex, vertex as f32);
        }

        // Nothing has been popped yet; every push is queued but only `l` candidates remain.
        assert!(state.visited.is_empty());
        assert_eq!(state.heap.len(), 40);
        assert_eq!(state.candidates.len(), 20);

        let mut popped = 0;
        while let Some(next) = state.pop() {
            state.visited.insert(next);
            assert_eq!(next, popped);
            popped += 1;
        }

        assert_eq!(popped, 20);
        assert_eq!(state.heap.len(), 0);
        assert_eq!(state.candidates.len(), 20);
    }
}