use std::sync::Arc;
use super::backend::RetrievalBackend;
use super::document::DocumentMeta;
use super::exact_backend::ExactCosineBackend;
use super::index::RagIndex;
use super::Result;
use crate::neural::ModernBertEmbedder;
/// Settings that control how many candidates a retriever fetches and which
/// of them are kept.
#[derive(Clone, Debug)]
pub struct RetrievalConfig {
    /// Maximum number of candidates requested from the index per query.
    pub top_k: usize,
    /// Candidates whose score is below this threshold are discarded.
    pub min_similarity: f32,
    /// Keep documents whose synopsis source is explicit.
    pub include_explicit_synopsis: bool,
    /// Keep documents whose synopsis source is anything other than explicit
    /// (i.e. generated).
    pub include_generated_synopsis: bool,
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
top_k: 10,
min_similarity: 0.0,
include_explicit_synopsis: true,
include_generated_synopsis: true,
}
}
}
/// One ranked document hit produced by a retrieval query.
#[derive(Clone, Debug)]
pub struct RetrievalResult {
    /// URI of the matched document.
    pub uri: String,
    /// Document title, when one is known.
    pub title: Option<String>,
    /// Synopsis text stored for the document.
    pub synopsis: String,
    /// `true` when the document's synopsis source is `SynopsisSource::Explicit`.
    pub synopsis_is_explicit: bool,
    /// Similarity score reported by the index for this document.
    pub score: f32,
    /// 1-based rank assigned by the retriever.
    pub rank: usize,
}
impl RetrievalResult {
pub fn from_meta(meta: &DocumentMeta, score: f32, rank: usize) -> Self {
Self {
uri: meta.uri.clone(),
title: meta.title.clone(),
synopsis: meta.synopsis.clone(),
synopsis_is_explicit: matches!(
meta.synopsis_source,
crate::neural::SynopsisSource::Explicit
),
score,
rank,
}
}
pub fn display_title(&self) -> &str {
self.title.as_deref().unwrap_or(&self.uri)
}
}
/// Embeds text queries and looks them up in a shared [`RagIndex`].
///
/// The backend type defaults to [`ExactCosineBackend`].
pub struct Retriever<B: RetrievalBackend = ExactCosineBackend> {
    /// Shared handle to the index being searched.
    index: Arc<RagIndex<B>>,
    /// Model used to turn query text into an embedding.
    embedder: ModernBertEmbedder,
    /// Result-count and filtering settings applied to every query.
    config: RetrievalConfig,
}
impl<B: RetrievalBackend> Retriever<B> {
pub fn new(
index: Arc<RagIndex<B>>,
embedder: ModernBertEmbedder,
config: RetrievalConfig,
) -> Self {
Self {
index,
embedder,
config,
}
}
pub fn query(&mut self, query: &str) -> Result<Vec<RetrievalResult>> {
let embedding = self.embedder.embed_query(query)?;
self.query_with_embedding(&embedding)
}
pub fn query_with_embedding(&self, embedding: &[f32]) -> Result<Vec<RetrievalResult>> {
let raw_results = self.index.query(embedding, self.config.top_k);
let results: Vec<RetrievalResult> = raw_results
.into_iter()
.enumerate()
.filter(|(_, (meta, score))| {
if *score < self.config.min_similarity {
return false;
}
let is_explicit = matches!(
meta.synopsis_source,
crate::neural::SynopsisSource::Explicit
);
if is_explicit && !self.config.include_explicit_synopsis {
return false;
}
if !is_explicit && !self.config.include_generated_synopsis {
return false;
}
true
})
.map(|(i, (meta, score))| RetrievalResult::from_meta(&meta, score, i + 1))
.collect();
Ok(results)
}
pub fn config(&self) -> &RetrievalConfig {
&self.config
}
pub fn set_config(&mut self, config: RetrievalConfig) {
self.config = config;
}
pub fn index(&self) -> &RagIndex<B> {
&self.index
}
pub fn embedder(&self) -> &ModernBertEmbedder {
&self.embedder
}
pub fn embedder_mut(&mut self) -> &mut ModernBertEmbedder {
&mut self.embedder
}
}
impl<B: RetrievalBackend> std::fmt::Debug for Retriever<B> {
    /// Summarises the retriever by the index's `len()` and the configuration.
    /// The embedder is not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let index_size = self.index.len();
        let mut summary = f.debug_struct("Retriever");
        summary.field("index_size", &index_size);
        summary.field("config", &self.config);
        summary.finish()
    }
}
/// Runs several text queries, in order, through a single [`Retriever`].
pub struct BatchRetriever<B: RetrievalBackend = ExactCosineBackend> {
    /// The retriever used for every query in a batch.
    retriever: Retriever<B>,
}
impl<B: RetrievalBackend> BatchRetriever<B> {
pub fn new(retriever: Retriever<B>) -> Self {
Self { retriever }
}
pub fn query_batch(&mut self, queries: &[&str]) -> Result<Vec<Vec<RetrievalResult>>> {
queries.iter().map(|q| self.retriever.query(q)).collect()
}
pub fn retriever(&self) -> &Retriever<B> {
&self.retriever
}
pub fn retriever_mut(&mut self) -> &mut Retriever<B> {
&mut self.retriever
}
}
/// Renders results as plain text, one entry per result.
///
/// Each entry has three lines followed by a blank line:
/// - a header with the rank, the score (two decimals) and the display title;
/// - the URI;
/// - the synopsis, labelled `explicit` or `generated`.
///
/// An empty slice yields an empty string.
pub fn format_results(results: &[RetrievalResult]) -> String {
    // `write!`/`writeln!` append straight into the buffer instead of building
    // a temporary `String` per line with `format!`.
    use std::fmt::Write as _;

    let mut output = String::new();
    for result in results {
        let synopsis_type = if result.synopsis_is_explicit {
            "explicit"
        } else {
            "generated"
        };
        // `fmt::Write` for `String` never fails, so the returned `fmt::Result`
        // carries no information and is ignored.
        let _ = writeln!(
            output,
            "{}. [{:.2}] {}",
            result.rank,
            result.score,
            result.display_title()
        );
        let _ = writeln!(output, " URI: {}", result.uri);
        let _ = writeln!(output, " Synopsis ({}): {}", synopsis_type, result.synopsis);
        output.push('\n');
    }
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_meta` copies metadata fields and flags explicit synopses.
    #[test]
    fn test_retrieval_result() {
        use super::super::document::{DocumentMeta, DocumentMetadata, LanguageTag};
        use crate::neural::SynopsisSource;

        let doc = DocumentMeta {
            uri: "test://doc".to_string(),
            title: Some("Test Document".to_string()),
            synopsis: "This is a test.".to_string(),
            synopsis_source: SynopsisSource::Explicit,
            language: LanguageTag::english_us(),
            metadata: DocumentMetadata::default(),
            topic_ids: Vec::new(),
        };

        let hit = RetrievalResult::from_meta(&doc, 0.95, 1);
        assert_eq!(hit.uri, "test://doc");
        assert_eq!(hit.display_title(), "Test Document");
        assert!(hit.synopsis_is_explicit);
        assert!((hit.score - 0.95).abs() < 1e-6);
    }

    /// The formatted text mentions titles, scores and both synopsis kinds.
    #[test]
    fn test_format_results() {
        let titled = RetrievalResult {
            uri: "test://1".to_string(),
            title: Some("First".to_string()),
            synopsis: "First doc.".to_string(),
            synopsis_is_explicit: true,
            score: 0.95,
            rank: 1,
        };
        let untitled = RetrievalResult {
            uri: "test://2".to_string(),
            title: None,
            synopsis: "Second doc.".to_string(),
            synopsis_is_explicit: false,
            score: 0.80,
            rank: 2,
        };

        let rendered = format_results(&[titled, untitled]);
        for needle in ["First", "0.95", "explicit", "generated"] {
            assert!(
                rendered.contains(needle),
                "missing {:?} in {:?}",
                needle,
                rendered
            );
        }
    }
}