use std::cmp::Ordering;
use std::collections::BinaryHeap;
use arrow::array::{Array, StringArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
use jammi_numerics::distance::cosine_distance;
use crate::error::{JammiError, Result};
use crate::store::vectors::extend_with_fixed_size_list_f32;
/// Total ordering over `(row_id, distance)` candidates: best match first.
///
/// Candidates are ordered by ascending distance, with ties broken by ascending
/// `row_id` so results are deterministic.
///
/// A NaN distance is ranked *after* every non-NaN distance (two NaN distances
/// tie and fall through to the `row_id` comparison). Treating NaN as "equal to
/// everything" would make the comparison non-transitive, which breaks the
/// contract `Ord`-based containers and `sort_by` rely on and would let
/// NaN-distance rows outrank real matches.
fn candidate_order(a: &(String, f32), b: &(String, f32)) -> Ordering {
    let by_distance = match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` always yields `Some`; the
        // fallback only exists to avoid an `unwrap`.
        (false, false) => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
    };
    by_distance.then_with(|| a.0.cmp(&b.0))
}
struct Candidate((String, f32));
impl PartialEq for Candidate {
fn eq(&self, other: &Self) -> bool {
candidate_order(&self.0, &other.0) == Ordering::Equal
}
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Candidate {
fn cmp(&self, other: &Self) -> Ordering {
candidate_order(&self.0, &other.0)
}
}
struct BoundedTopK {
k: usize,
heap: BinaryHeap<Candidate>,
}
impl BoundedTopK {
fn new(k: usize) -> Self {
Self {
k,
heap: BinaryHeap::with_capacity(k),
}
}
fn offer(&mut self, row_id: String, dist: f32) {
if self.k == 0 {
return;
}
let candidate = Candidate((row_id, dist));
if self.heap.len() < self.k {
self.heap.push(candidate);
} else if let Some(worst) = self.heap.peek() {
if candidate.cmp(worst) == Ordering::Less {
self.heap.pop();
self.heap.push(candidate);
}
}
}
fn into_sorted(self) -> Vec<(String, f32)> {
let mut out: Vec<(String, f32)> = self.heap.into_iter().map(|c| c.0).collect();
out.sort_by(candidate_order);
out
}
}
/// Brute-force nearest-neighbour search over the `jammi.{table_name}` table.
///
/// Streams the `_row_id` and `vector` columns batch by batch, computes
/// `cosine_distance(query, row_vector)` for every row, and keeps the `k` best
/// rows in a bounded heap so memory use does not grow with the table size.
///
/// Returns at most `k` `(row_id, distance)` pairs in the order produced by
/// `BoundedTopK::into_sorted`. When `k == 0` the query is planned (so an
/// unknown table or column is still reported) but not executed, and an empty
/// vector is returned.
///
/// # Errors
///
/// Fails if the SQL cannot be planned or executed, if a batch has no
/// `_row_id` column, if `_row_id` cannot be cast to `Utf8`, or if
/// `extend_with_fixed_size_list_f32` rejects the `vector` column.
pub async fn exact_vector_search(
    ctx: &SessionContext,
    table_name: &str,
    query: &[f32],
    k: usize,
) -> Result<Vec<(String, f32)>> {
    // Double any embedded `"` so the name cannot terminate the quoted
    // identifier early and change the statement.
    let quoted_table = table_name.replace('"', "\"\"");
    let df = ctx
        .sql(&format!(
            "SELECT _row_id, vector FROM \"jammi.{quoted_table}\""
        ))
        .await?;

    // Nothing can be retained, so skip the full table scan entirely.
    if k == 0 {
        return Ok(Vec::new());
    }

    let mut stream = df.execute_stream().await?;
    let mut top_k = BoundedTopK::new(k);
    // Scratch buffer reused across batches to avoid reallocating per batch.
    let mut vectors: Vec<Vec<f32>> = Vec::new();

    while let Some(batch) = stream.try_next().await? {
        let row_ids_col = batch
            .column_by_name("_row_id")
            .ok_or_else(|| JammiError::Other("Missing _row_id in exact search".into()))?;
        // Normalise whatever string-like type the column has to plain Utf8.
        let row_ids_utf8 = cast(row_ids_col, &DataType::Utf8).map_err(|e| {
            JammiError::Other(format!("_row_id column could not be cast to Utf8: {e}"))
        })?;
        let row_ids = row_ids_utf8
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| {
                JammiError::Other("_row_id column is not a Utf8-castable string type".into())
            })?;

        vectors.clear();
        extend_with_fixed_size_list_f32(&batch, table_name, "vector", &mut vectors)?;

        // NOTE(review): this assumes the helper appends exactly one vector per
        // batch row, in row order, so `offset` indexes both columns — confirm
        // against `extend_with_fixed_size_list_f32`.
        for (offset, vec) in vectors.iter().enumerate() {
            let dist = cosine_distance(query, vec);
            top_k.offer(row_ids.value(offset).to_string(), dist);
        }
    }

    Ok(top_k.into_sorted())
}