use crate::commands::query::QueryCommand;
use crate::models::Embedder;
use crate::retrieval::{merge_candidates, prune_candidates, RetrievalCandidate};
use crate::store::{self, build_retrieval_filter, connect_db, ensure_fts_index};
use anyhow::{Context, Result};
use arrow_array::{Float32Array, Float64Array, Int32Array, RecordBatch, StringArray};
use futures::TryStreamExt;
use lancedb::index::scalar::FullTextSearchQuery;
use lancedb::query::{QueryBase, QueryExecutionOptions};
use tracing::{field, Instrument};
use super::rerank::rerank_candidates;
use super::runtime::retrieval_limit;
use super::types::QueryRuntime;
pub(crate) async fn retrieve_candidates(
runtime: &QueryRuntime,
command: &QueryCommand,
queries: &[String],
trace: &mut Vec<String>,
) -> Result<Vec<RetrievalCandidate>> {
let span = tracing::info_span!(
"retrieve_candidates",
query_variants = queries.len(),
top_k = command.top_k,
fetch_k = command.fetch_k,
has_source_filter = command.source.is_some(),
has_path_prefix_filter = command.path_prefix.is_some(),
page = field::debug(command.page),
has_format_filter = command.format.is_some(),
merged_candidates = field::Empty,
pruned_candidates = field::Empty,
);
let span_inner = span.clone();
async move {
let mut groups = Vec::new();
for query in queries {
groups.push(retrieve_candidates_for_query(runtime, query, command).await?);
}
let merged = merge_candidates(groups);
span_inner.record("merged_candidates", merged.len());
trace.push(format!(
"merged {} candidate(s) across {} retrieval query variant(s)",
merged.len(),
queries.len()
));
let reranked = rerank_candidates(runtime, command, merged, trace).await;
let pruned = prune_candidates(reranked, command.top_k);
span_inner.record("pruned_candidates", pruned.len());
trace.push(format!("kept {} candidate(s) after pruning", pruned.len()));
Ok(pruned)
}
.instrument(span)
.await
}
async fn retrieve_candidates_for_query(
runtime: &QueryRuntime,
question: &str,
command: &QueryCommand,
) -> Result<Vec<RetrievalCandidate>> {
let span = tracing::info_span!(
"retrieve_query_variant",
query_chars = question.chars().count(),
retrieval_limit = retrieval_limit(command),
has_source_filter = command.source.is_some(),
has_path_prefix_filter = command.path_prefix.is_some(),
page = field::debug(command.page),
has_format_filter = command.format.is_some(),
batch_count = field::Empty,
hit_count = field::Empty,
);
let span_inner = span.clone();
async move {
let db = connect_db(&runtime.store).await?;
let table = db
.open_table(store::DEFAULT_TABLE_NAME)
.execute()
.await
.context("open table")?;
ensure_fts_index(&table, false).await?;
let embedder = Embedder::new(
runtime.cfg.ollama.base_url.clone(),
runtime.embed_model_name.clone(),
);
let embedding = embedder.embed(question).await?;
let mut query = table
.query()
.full_text_search(FullTextSearchQuery::new(question.to_string()))
.nearest_to(embedding.as_slice())?
.limit(retrieval_limit(command));
if let Some(filter) = build_retrieval_filter(
command.source.as_deref(),
command.path_prefix.as_deref(),
command.page,
command.format.as_deref(),
) {
query = query.only_if(filter);
}
let batches: Vec<RecordBatch> = query
.execute_hybrid(QueryExecutionOptions::default())
.await?
.try_collect::<Vec<_>>()
.await?;
span_inner.record("batch_count", batches.len());
let hits = extract_candidates(&batches)?;
span_inner.record("hit_count", hits.len());
Ok(hits)
}
.instrument(span)
.await
}
fn extract_candidates(batches: &[RecordBatch]) -> Result<Vec<RetrievalCandidate>> {
let mut hits = Vec::new();
for batch in batches {
let id_col = batch
.column_by_name("id")
.and_then(|column| column.as_any().downcast_ref::<StringArray>());
let text_col = batch
.column_by_name("chunk_text")
.context("chunk_text column missing")?
.as_any()
.downcast_ref::<StringArray>()
.context("chunk_text column type")?;
let source_col = batch
.column_by_name("source_path")
.context("source_path column missing")?
.as_any()
.downcast_ref::<StringArray>()
.context("source_path column type")?;
let metadata_col = batch
.column_by_name("metadata")
.and_then(|column| column.as_any().downcast_ref::<StringArray>());
let page_col = batch
.column_by_name("page")
.and_then(|column| column.as_any().downcast_ref::<Int32Array>());
let chunk_index_col = batch
.column_by_name("chunk_index")
.and_then(|column| column.as_any().downcast_ref::<Int32Array>());
for row in 0..batch.num_rows() {
let fused_score = hybrid_relevance_score_at(batch, row);
let vector_score = vector_score_at(batch, row, fused_score.is_some());
let keyword_score = if fused_score.is_some() {
None
} else {
raw_score_at(batch, row)
};
hits.push(RetrievalCandidate {
id: id_col
.map(|column| column.value(row).to_string())
.unwrap_or_default(),
source_path: source_col.value(row).to_string(),
chunk_text: text_col.value(row).to_string(),
metadata: metadata_col
.map(|column| column.value(row).to_string())
.unwrap_or_default(),
page: page_col.map(|column| column.value(row)).unwrap_or_default(),
chunk_index: chunk_index_col
.map(|column| column.value(row))
.unwrap_or_default(),
vector_score,
keyword_score,
fused_score: fused_score.or(vector_score).or(keyword_score),
rerank_score: None,
});
}
}
Ok(hits)
}
fn hybrid_relevance_score_at(batch: &RecordBatch, row: usize) -> Option<f32> {
numeric_column_value(batch, row, "_relevance_score")
}
fn raw_score_at(batch: &RecordBatch, row: usize) -> Option<f32> {
["_score", "score"]
.into_iter()
.find_map(|name| numeric_column_value(batch, row, name))
}
fn vector_score_at(batch: &RecordBatch, row: usize, has_hybrid_relevance: bool) -> Option<f32> {
for name in ["_distance", "distance"] {
if let Some(distance) = numeric_column_value(batch, row, name) {
return Some(distance_to_similarity(distance));
}
}
if has_hybrid_relevance {
return None;
}
raw_score_at(batch, row)
}
fn numeric_column_value(batch: &RecordBatch, row: usize, name: &str) -> Option<f32> {
let column = batch.column_by_name(name)?;
if let Some(values) = column.as_any().downcast_ref::<Float32Array>() {
return Some(values.value(row));
}
if let Some(values) = column.as_any().downcast_ref::<Float64Array>() {
return Some(values.value(row) as f32);
}
None
}
fn distance_to_similarity(distance: f32) -> f32 {
1.0 / (1.0 + distance.max(0.0))
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Float32Array, Int32Array, StringArray};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
#[test]
fn test_distance_to_similarity_makes_smaller_distances_score_higher() {
assert!(distance_to_similarity(0.2) > distance_to_similarity(0.8));
}
#[test]
fn test_extract_candidates_prefers_hybrid_relevance_over_raw_score() {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("source_path", DataType::Utf8, false),
Field::new("chunk_text", DataType::Utf8, false),
Field::new("page", DataType::Int32, false),
Field::new("chunk_index", DataType::Int32, false),
Field::new("_score", DataType::Float32, true),
Field::new("_relevance_score", DataType::Float32, false),
])),
vec![
Arc::new(StringArray::from_iter_values(["docs/totoro.md"])),
Arc::new(StringArray::from_iter_values([
"Chibi Totoro and Chu Totoro",
])),
Arc::new(Int32Array::from_iter_values([3])),
Arc::new(Int32Array::from_iter_values([0])),
Arc::new(Float32Array::from_iter_values([0.000001])),
Arc::new(Float32Array::from_iter_values([0.032795697])),
],
)
.unwrap();
let candidates = extract_candidates(&[batch]).unwrap();
assert_eq!(candidates[0].fused_score, Some(0.032795697));
assert_eq!(candidates[0].keyword_score, None);
assert_eq!(candidates[0].vector_score, None);
}
}