use arrow::array::{Array, StringArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use datafusion::prelude::SessionContext;
use jammi_numerics::distance::cosine_distance;
use crate::error::{JammiError, Result};
use crate::store::vectors::extend_with_fixed_size_list_f32;
/// Brute-force top-`k` search over every row of the `jammi.<table_name>` table.
///
/// Reads the `_row_id` and `vector` columns of the table through `ctx`, computes
/// [`cosine_distance`] between `query` and each stored vector, and returns at most
/// `k` `(row_id, distance)` pairs ordered by ascending distance. Equal distances
/// are ordered by row id so the output is deterministic. Rows whose distance is
/// NaN are ranked after every row with a real distance.
///
/// `_row_id` is cast to `Utf8`, so any column type Arrow can cast to a string is
/// accepted.
///
/// # Errors
///
/// Returns an error when:
/// - the SQL query cannot be planned or executed,
/// - a batch has no `_row_id` column or it cannot be cast to `Utf8`,
/// - `extend_with_fixed_size_list_f32` fails to decode the `vector` column,
/// - the number of decoded vectors differs from the number of row ids in a batch.
pub async fn exact_vector_search(
    ctx: &SessionContext,
    table_name: &str,
    query: &[f32],
    k: usize,
) -> Result<Vec<(String, f32)>> {
    // NOTE(review): `table_name` is interpolated into a quoted identifier without
    // escaping; presumably it is always an internally generated name — confirm.
    let df = ctx
        .sql(&format!(
            "SELECT _row_id, vector FROM \"jammi.{table_name}\""
        ))
        .await?;
    let batches = df.collect().await?;

    let mut scored: Vec<(String, f32)> = Vec::new();
    let mut vectors: Vec<Vec<f32>> = Vec::new();

    for batch in &batches {
        let row_ids_col = batch
            .column_by_name("_row_id")
            .ok_or_else(|| JammiError::Other("Missing _row_id in exact search".into()))?;
        let row_ids_utf8 = cast(row_ids_col, &DataType::Utf8).map_err(|e| {
            JammiError::Other(format!("_row_id column could not be cast to Utf8: {e}"))
        })?;
        let row_ids = row_ids_utf8
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| {
                JammiError::Other("_row_id column is not a Utf8-castable string type".into())
            })?;

        // The helper appends to `vectors`; everything past `before` belongs to
        // this batch.
        let before = vectors.len();
        extend_with_fixed_size_list_f32(batch, table_name, "vector", &mut vectors)?;
        let decoded = &vectors[before..];

        // Row ids are matched to vectors purely by position. If the counts
        // differ, ids would be attached to the wrong vectors (or indexing the id
        // array would panic), so fail loudly instead.
        if decoded.len() != row_ids.len() {
            return Err(JammiError::Other(format!(
                "exact search on {table_name}: decoded {} vectors for {} row ids",
                decoded.len(),
                row_ids.len()
            )));
        }

        scored.extend(decoded.iter().enumerate().map(|(offset, vec)| {
            (
                row_ids.value(offset).to_string(),
                cosine_distance(query, vec),
            )
        }));
    }

    scored.sort_by(compare_hits);
    scored.truncate(k);
    Ok(scored)
}

/// Orders `(row_id, distance)` pairs by ascending distance, NaN last, then by row id.
///
/// Treating NaN as "equal to everything" is not a total order, and `sort_by` may
/// panic or produce an unspecified order when given one; ranking NaN explicitly
/// keeps the comparison consistent.
fn compare_hits(a: &(String, f32), b: &(String, f32)) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    let by_distance = match (a.1.is_nan(), b.1.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` always yields `Some`.
        (false, false) => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
    };
    by_distance.then_with(|| a.0.cmp(&b.0))
}