use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::RetrieverError;
use crate::language_models::llm::LLM;
use crate::schemas::{Document, Retriever};
/// Tunable settings for [`MultiQueryRetriever`].
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiQueryRetrieverConfig {
    /// Number of alternative phrasings requested from the LLM. It is also the
    /// upper bound on how many generated lines are kept as queries.
    pub num_queries: usize,
    /// Optional custom prompt. The placeholders `{query}` and `{num_queries}`
    /// are substituted with the user query and [`Self::num_queries`]. When
    /// `None`, a built-in prompt is used instead.
    pub prompt_template: Option<String>,
}
impl Default for MultiQueryRetrieverConfig {
fn default() -> Self {
Self {
num_queries: 3,
prompt_template: None,
}
}
}
/// Retriever that asks an LLM for alternative phrasings of a query, runs each
/// phrasing against a base retriever, and merges the results.
pub struct MultiQueryRetriever {
    /// Retriever that is queried once per query variant.
    base_retriever: Arc<dyn Retriever>,
    /// Language model used to generate the alternative phrasings.
    llm: Arc<dyn LLM>,
    /// Settings controlling how many variants are requested and which prompt
    /// is sent to the model.
    config: MultiQueryRetrieverConfig,
}
impl MultiQueryRetriever {
    /// Creates a retriever that uses [`MultiQueryRetrieverConfig::default`].
    pub fn new(base_retriever: Arc<dyn Retriever>, llm: Arc<dyn LLM>) -> Self {
        Self::with_config(base_retriever, llm, MultiQueryRetrieverConfig::default())
    }

    /// Creates a retriever with an explicit configuration.
    ///
    /// * `base_retriever` - retriever queried once per query variant.
    /// * `llm` - model asked to produce the alternative phrasings.
    /// * `config` - number of variants to request and optional custom prompt.
    pub fn with_config(
        base_retriever: Arc<dyn Retriever>,
        llm: Arc<dyn LLM>,
        config: MultiQueryRetrieverConfig,
    ) -> Self {
        Self {
            base_retriever,
            llm,
            config,
        }
    }

    /// Renders the prompt that asks the LLM for alternative phrasings of
    /// `original_query`, using the configured template when one is set and
    /// the built-in prompt otherwise.
    fn build_prompt(&self, original_query: &str) -> String {
        let num_queries = self.config.num_queries;
        match self.config.prompt_template.as_deref() {
            // `{num_queries}` is substituted before `{query}` so that
            // placeholder-looking text inside the user's query is inserted
            // verbatim instead of being rewritten by a later substitution.
            Some(template) => template
                .replace("{num_queries}", &num_queries.to_string())
                .replace("{query}", original_query),
            None => format!(
                "You are an AI language model assistant. Your task is to generate {} different versions \
                of the given user question to retrieve relevant documents from a vector database. \
                By generating multiple perspectives on the user question, your goal is to help \
                the user overcome some of the limitations of distance-based similarity search. \
                Provide these alternative questions separated by newlines.\n\n\
                Original question: {}\n\n\
                Alternative questions:",
                num_queries, original_query
            ),
        }
    }

    /// Asks the LLM for alternative phrasings of `original_query`.
    ///
    /// Returns the original query first, followed by at most
    /// `config.num_queries` non-empty, trimmed lines of the model output.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the LLM call.
    async fn generate_queries(&self, original_query: &str) -> Result<Vec<String>, Box<dyn Error>> {
        let prompt = self.build_prompt(original_query);
        let messages = vec![crate::schemas::messages::Message::new_human_message(
            &prompt,
        )];
        let result = self.llm.generate(&messages).await?;

        // The original query always leads so it is searched even when the
        // model returns nothing usable.
        let mut all_queries = vec![original_query.to_string()];
        all_queries.extend(
            result
                .generation
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .take(self.config.num_queries)
                .map(String::from),
        );
        Ok(all_queries)
    }

    /// Flattens the per-query result lists into one list, keeping only the
    /// first document seen for each distinct `page_content` and preserving
    /// the order in which documents were first encountered.
    fn merge_results(&self, all_results: Vec<Vec<Document>>) -> Vec<Document> {
        // Used purely as a set of already-seen contents: `insert` returns
        // `None` only for a new key, so one lookup and one clone per document
        // are enough.
        let mut seen = HashMap::new();
        let mut merged = Vec::new();
        for doc in all_results.into_iter().flatten() {
            if seen.insert(doc.page_content.clone(), ()).is_none() {
                merged.push(doc);
            }
        }
        merged
    }
}
#[async_trait]
impl Retriever for MultiQueryRetriever {
    /// Retrieves documents for `query` and for LLM-generated rephrasings of
    /// it, returning the merged, de-duplicated results.
    ///
    /// Retrieval is best-effort per query variant: a variant whose lookup
    /// fails is logged and skipped as long as at least one variant succeeds.
    ///
    /// # Errors
    ///
    /// Returns `RetrieverError::DocumentProcessingError` when the LLM fails to
    /// generate query variants, or when the base retriever fails for every
    /// variant (so a total outage is not reported as "no documents found").
    async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, RetrieverError> {
        let queries = self
            .generate_queries(query)
            .await
            .map_err(|e| RetrieverError::DocumentProcessingError(e.to_string()))?;

        let mut all_results = Vec::with_capacity(queries.len());
        // Only a count is tracked, so no error value is kept alive across an
        // `.await` point.
        let mut failures: usize = 0;
        for q in &queries {
            match self.base_retriever.get_relevant_documents(q).await {
                Ok(results) => all_results.push(results),
                Err(_e) => {
                    // Only the query length is logged, never its content.
                    log::warn!(
                        "Failed to retrieve documents for generated query (length: {})",
                        q.len()
                    );
                    failures += 1;
                }
            }
        }

        // Every lookup failed: surface the failure instead of an empty result.
        if all_results.is_empty() && failures > 0 {
            return Err(RetrieverError::DocumentProcessingError(format!(
                "base retriever failed for all {} queries",
                failures
            )));
        }

        Ok(self.merge_results(all_results))
    }
}