use std::sync::Arc;
use async_trait::async_trait;
use crate::{
chain::{Chain, ConversationalRetrieverChainBuilder},
error::RetrieverError,
language_models::llm::LLM,
memory::SimpleMemory,
prompt::PromptArgs,
schemas::{BaseMemory, Document, Retriever},
};
struct RetrieverWrapper(Arc<dyn Retriever>);
#[async_trait]
impl Retriever for RetrieverWrapper {
async fn get_relevant_documents(&self, query: &str) -> Result<Vec<Document>, RetrieverError> {
self.0.get_relevant_documents(query).await
}
}
use super::{
answer_validator::AnswerValidator, query_enhancer::QueryEnhancer,
retrieval_validator::RetrievalValidator,
};
use crate::rag::RAGError;
pub struct HybridRAGConfig {
pub max_retrieval_retries: usize,
pub max_generation_retries: usize,
pub enable_query_enhancement: bool,
pub enable_retrieval_validation: bool,
pub enable_answer_validation: bool,
}
impl Default for HybridRAGConfig {
fn default() -> Self {
Self {
max_retrieval_retries: 2,
max_generation_retries: 2,
enable_query_enhancement: true,
enable_retrieval_validation: true,
enable_answer_validation: true,
}
}
}
pub struct HybridRAG {
retriever: Arc<dyn Retriever>,
llm: Box<dyn LLM>,
memory: Arc<tokio::sync::Mutex<dyn BaseMemory>>,
query_enhancer: Option<Arc<dyn QueryEnhancer>>,
retrieval_validator: Option<Arc<dyn RetrievalValidator>>,
answer_validator: Option<Arc<dyn AnswerValidator>>,
config: HybridRAGConfig,
}
impl HybridRAG {
pub fn new(
retriever: Arc<dyn Retriever>,
llm: Box<dyn LLM>,
memory: Arc<tokio::sync::Mutex<dyn BaseMemory>>,
) -> Self {
Self {
retriever,
llm,
memory,
query_enhancer: None,
retrieval_validator: None,
answer_validator: None,
config: HybridRAGConfig::default(),
}
}
pub fn with_query_enhancer(mut self, enhancer: Arc<dyn QueryEnhancer>) -> Self {
self.query_enhancer = Some(enhancer);
self
}
pub fn with_retrieval_validator(mut self, validator: Arc<dyn RetrievalValidator>) -> Self {
self.retrieval_validator = Some(validator);
self
}
pub fn with_answer_validator(mut self, validator: Arc<dyn AnswerValidator>) -> Self {
self.answer_validator = Some(validator);
self
}
pub fn with_config(mut self, config: HybridRAGConfig) -> Self {
self.config = config;
self
}
pub async fn invoke(&self, query: &str) -> Result<String, RAGError> {
let mut current_query = query.to_string();
let mut documents;
if self.config.enable_query_enhancement {
if let Some(enhancer) = &self.query_enhancer {
let enhanced = enhancer.enhance(¤t_query).await?;
current_query = enhanced.query;
}
}
let mut retrieval_attempts = 0;
loop {
documents = self
.retriever
.get_relevant_documents(¤t_query)
.await
.map_err(|e| RAGError::RetrieverError(e.to_string()))?;
if self.config.enable_retrieval_validation {
if let Some(validator) = &self.retrieval_validator {
let validation = validator.validate(¤t_query, &documents).await?;
if !validation.is_valid
&& retrieval_attempts < self.config.max_retrieval_retries
{
if let Some(suggestion) = validation.suggestions.first() {
current_query = format!("{} {}", current_query, suggestion);
}
retrieval_attempts += 1;
continue;
} else if !validation.is_valid {
log::warn!(
"Retrieval validation failed but max retries reached: {:?}",
validation.feedback
);
}
}
}
break;
}
let mut generation_attempts = 0;
let mut answer;
loop {
let retriever_box: Box<dyn Retriever> =
Box::new(RetrieverWrapper(self.retriever.clone()));
let chain = ConversationalRetrieverChainBuilder::new()
.llm(self.llm.clone_box())
.retriever(retriever_box)
.memory(Arc::clone(&self.memory))
.rephrase_question(false) .build()
.map_err(|e| RAGError::ChainError(e))?;
let mut prompt_args = PromptArgs::new();
prompt_args.insert("question".to_string(), serde_json::json!(current_query));
answer = chain
.invoke(prompt_args)
.await
.map_err(|e| RAGError::ChainError(e))?;
if self.config.enable_answer_validation {
if let Some(validator) = &self.answer_validator {
let validation = validator
.validate(¤t_query, &answer, &documents)
.await?;
if !validation.is_valid
&& generation_attempts < self.config.max_generation_retries
{
if let Some(suggestion) = validation.suggestions.first() {
current_query = format!("{} {}", current_query, suggestion);
}
generation_attempts += 1;
continue;
} else if !validation.is_valid {
log::warn!(
"Answer validation failed but max retries reached: {:?}",
validation.feedback
);
}
}
}
break;
}
Ok(answer)
}
}
pub struct HybridRAGBuilder {
retriever: Option<Arc<dyn Retriever>>,
llm: Option<Box<dyn LLM>>,
memory: Option<Arc<tokio::sync::Mutex<dyn BaseMemory>>>,
query_enhancer: Option<Arc<dyn QueryEnhancer>>,
retrieval_validator: Option<Arc<dyn RetrievalValidator>>,
answer_validator: Option<Arc<dyn AnswerValidator>>,
config: HybridRAGConfig,
}
impl HybridRAGBuilder {
pub fn new() -> Self {
Self {
retriever: None,
llm: None,
memory: None,
query_enhancer: None,
retrieval_validator: None,
answer_validator: None,
config: HybridRAGConfig::default(),
}
}
pub fn with_retriever(mut self, retriever: Arc<dyn Retriever>) -> Self {
self.retriever = Some(retriever);
self
}
pub fn with_llm<L: Into<Box<dyn LLM>>>(mut self, llm: L) -> Self {
self.llm = Some(llm.into());
self
}
pub fn with_memory(mut self, memory: Arc<tokio::sync::Mutex<dyn BaseMemory>>) -> Self {
self.memory = Some(memory);
self
}
pub fn with_query_enhancer(mut self, enhancer: Arc<dyn QueryEnhancer>) -> Self {
self.query_enhancer = Some(enhancer);
self
}
pub fn with_retrieval_validator(mut self, validator: Arc<dyn RetrievalValidator>) -> Self {
self.retrieval_validator = Some(validator);
self
}
pub fn with_answer_validator(mut self, validator: Arc<dyn AnswerValidator>) -> Self {
self.answer_validator = Some(validator);
self
}
pub fn with_config(mut self, config: HybridRAGConfig) -> Self {
self.config = config;
self
}
pub fn build(self) -> Result<HybridRAG, RAGError> {
let retriever = self
.retriever
.ok_or_else(|| RAGError::InvalidConfiguration("Retriever must be set".to_string()))?;
let llm = self
.llm
.ok_or_else(|| RAGError::InvalidConfiguration("LLM must be set".to_string()))?;
let memory = self
.memory
.unwrap_or_else(|| Arc::new(tokio::sync::Mutex::new(SimpleMemory::new())));
let mut hybrid_rag = HybridRAG::new(retriever, llm, memory);
if let Some(enhancer) = self.query_enhancer {
hybrid_rag = hybrid_rag.with_query_enhancer(enhancer);
}
if let Some(validator) = self.retrieval_validator {
hybrid_rag = hybrid_rag.with_retrieval_validator(validator);
}
if let Some(validator) = self.answer_validator {
hybrid_rag = hybrid_rag.with_answer_validator(validator);
}
hybrid_rag = hybrid_rag.with_config(self.config);
Ok(hybrid_rag)
}
}
impl Default for HybridRAGBuilder {
fn default() -> Self {
Self::new()
}
}