use crate::{
chain::{Chain, ChainError, ConversationalRetrieverChain},
language_models::GenerateResult,
prompt::PromptArgs,
};
use crate::rag::RAGError;
/// Thin wrapper around a [`ConversationalRetrieverChain`].
///
/// It offers question-in/answer-out helpers that take a plain `&str`. It also
/// implements [`Chain`] by forwarding to the wrapped chain.
pub struct TwoStepRAG {
    /// The underlying retriever chain that every request is forwarded to.
    chain: ConversationalRetrieverChain,
}
impl TwoStepRAG {
pub fn from_chain(chain: ConversationalRetrieverChain) -> Self {
Self { chain }
}
pub async fn invoke(&self, question: &str) -> Result<String, RAGError> {
let mut prompt_args = PromptArgs::new();
prompt_args.insert("question".to_string(), serde_json::json!(question));
let result = self.chain.invoke(prompt_args).await?;
Ok(result)
}
pub async fn call(&self, question: &str) -> Result<GenerateResult, RAGError> {
let mut prompt_args = PromptArgs::new();
prompt_args.insert("question".to_string(), serde_json::json!(question));
let result = self.chain.call(prompt_args).await?;
Ok(result)
}
pub fn chain(&self) -> &ConversationalRetrieverChain {
&self.chain
}
}
/// Exposes `TwoStepRAG` as a generic [`Chain`].
///
/// The caller's [`PromptArgs`] are handed to the wrapped
/// [`ConversationalRetrieverChain`] unchanged, and its result or error is
/// returned as-is.
// NOTE(review): only `invoke` and `call` are forwarded here. Any other `Chain`
// methods resolve to the trait's default implementations, not to the inner
// chain's overrides — confirm that is intended.
#[async_trait::async_trait]
impl Chain for TwoStepRAG {
    async fn invoke(&self, inputs: PromptArgs) -> Result<String, ChainError> {
        let inner = &self.chain;
        inner.invoke(inputs).await
    }

    async fn call(&self, inputs: PromptArgs) -> Result<GenerateResult, ChainError> {
        let inner = &self.chain;
        inner.call(inputs).await
    }
}