use crate::{
chains::{utils::build_prompt, RagChainError},
clients::{AsyncChatClient, AsyncStreamedChatClient, PromptMessage},
common::Chunks,
retrievers::AsyncRetriever,
};
use std::num::NonZeroU32;
use typed_builder::TypedBuilder;
/// Retrieval-augmented chat chain that answers with a single message.
///
/// Holds a retriever used to fetch supporting chunks for a user message and a
/// chat client that receives the resulting prompts. Instances are created
/// through the builder generated by `TypedBuilder`.
#[derive(Clone, Debug, Eq, PartialEq, TypedBuilder)]
pub struct BasicRAGChain<T: AsyncChatClient, U: AsyncRetriever> {
    /// Optional message placed ahead of the user prompt when the chain runs.
    /// Defaults to `None`; the builder setter takes the bare `PromptMessage`.
    #[builder(default, setter(strip_option))]
    system_prompt: Option<PromptMessage>,
    /// Client that receives the assembled prompts.
    chat_client: T,
    /// Source of the supporting chunks for a user message.
    retriever: U,
}
impl<T, U> BasicRAGChain<T, U>
where
    T: AsyncChatClient,
    U: AsyncRetriever,
{
    /// Runs the chain once for `user_message`.
    ///
    /// The content of `user_message` is passed to the retriever together with
    /// `top_k`. The returned chunks and the original message are combined by
    /// `build_prompt`, and the result is sent to the chat client, preceded by
    /// the system prompt when one is configured.
    ///
    /// # Errors
    ///
    /// * `RagChainError::RetrieverError` when the retriever call fails.
    /// * `RagChainError::ChatClientError` when the chat client call fails.
    pub async fn invoke_chain(
        &self,
        user_message: PromptMessage,
        top_k: NonZeroU32,
    ) -> Result<PromptMessage, RagChainError<T::ErrorType, U::ErrorType>> {
        let retrieved: Chunks = self
            .retriever
            .retrieve(user_message.content(), top_k)
            .await
            .map_err(RagChainError::<T::ErrorType, U::ErrorType>::RetrieverError)?;

        let augmented: PromptMessage = build_prompt(&user_message, retrieved);

        // The system prompt, when present, always comes first.
        let mut messages: Vec<PromptMessage> = Vec::with_capacity(2);
        if let Some(system) = &self.system_prompt {
            messages.push(system.clone());
        }
        messages.push(augmented);

        self.chat_client
            .invoke(messages)
            .await
            .map_err(RagChainError::<T::ErrorType, U::ErrorType>::ChatClientError)
    }
}
/// Retrieval-augmented chat chain backed by a streaming chat client.
///
/// Holds a retriever used to fetch supporting chunks for a user message and a
/// streamed chat client that receives the resulting prompts. Instances are
/// created through the builder generated by `TypedBuilder`.
#[derive(Clone, Debug, Eq, PartialEq, TypedBuilder)]
pub struct BasicStreamedRAGChain<T: AsyncStreamedChatClient, U: AsyncRetriever> {
    /// Optional message placed ahead of the user prompt when the chain runs.
    /// Defaults to `None`; the builder setter takes the bare `PromptMessage`.
    #[builder(default, setter(strip_option))]
    system_prompt: Option<PromptMessage>,
    /// Streaming client that receives the assembled prompts.
    chat_client: T,
    /// Source of the supporting chunks for a user message.
    retriever: U,
}
impl<T, U> BasicStreamedRAGChain<T, U>
where
    T: AsyncStreamedChatClient,
    U: AsyncRetriever,
{
    /// Runs the chain once for `user_message` and hands back the client's
    /// `T::Item` (the success value of `invoke_stream`).
    ///
    /// The content of `user_message` is passed to the retriever together with
    /// `top_k`. The returned chunks and the original message are combined by
    /// `build_prompt`, and the result is sent to the streamed chat client,
    /// preceded by the system prompt when one is configured.
    ///
    /// # Errors
    ///
    /// * `RagChainError::RetrieverError` when the retriever call fails.
    /// * `RagChainError::ChatClientError` when `invoke_stream` fails.
    pub async fn invoke_chain(
        &self,
        user_message: PromptMessage,
        top_k: NonZeroU32,
    ) -> Result<T::Item, RagChainError<T::ErrorType, U::ErrorType>> {
        let retrieved: Chunks = self
            .retriever
            .retrieve(user_message.content(), top_k)
            .await
            .map_err(RagChainError::<T::ErrorType, U::ErrorType>::RetrieverError)?;

        let augmented: PromptMessage = build_prompt(&user_message, retrieved);

        // The system prompt, when present, always comes first.
        let mut messages: Vec<PromptMessage> = Vec::with_capacity(2);
        if let Some(system) = &self.system_prompt {
            messages.push(system.clone());
        }
        messages.push(augmented);

        self.chat_client
            .invoke_stream(messages)
            .await
            .map_err(RagChainError::<T::ErrorType, U::ErrorType>::ChatClientError)
    }
}
#[cfg(test)]
mod basic_rag_chain_tests {
    use super::*;
    use crate::{
        clients::{
            ChatCompletionStream, MockAsyncChatClient, MockAsyncStreamedChatClient,
            MockChatCompletionStream,
        },
        common::Chunk,
        retrievers::MockAsyncRetriever,
    };
    use mockall::predicate::eq;
    use std::vec;

    // Fixtures shared by both chain tests.
    const SYSTEM_MESSAGE: &str = "you are a study buddy";
    const USER_MESSAGE: &str = "please tell me about my lecture on operating systems";
    const RAG_CHUNK_1: &str = "data point 1";
    const RAG_CHUNK_2: &str = "data point 2";

    /// Chunk count requested from the retriever in every test.
    fn top_k() -> NonZeroU32 {
        NonZeroU32::new(2).unwrap()
    }

    /// System prompt configured on the chain under test.
    fn system_prompt() -> PromptMessage {
        PromptMessage::SystemMessage(SYSTEM_MESSAGE.into())
    }

    /// Prompts the chat client is expected to receive: the system prompt
    /// followed by the user message augmented with both chunks.
    fn expected_prompts() -> Vec<PromptMessage> {
        let augmented: String = format!(
            "{}\n{}\n{}\n{}\n",
            USER_MESSAGE, "Here is some supporting information:", RAG_CHUNK_1, RAG_CHUNK_2
        );
        vec![
            system_prompt(),
            PromptMessage::HumanMessage(augmented.into()),
        ]
    }

    /// Reply produced by the mocked chat clients.
    fn mocked_response() -> PromptMessage {
        PromptMessage::AIMessage("mocked response".into())
    }

    /// Retriever mock that expects the raw user message and yields two chunks.
    fn retriever_with_two_chunks() -> MockAsyncRetriever {
        let mut retriever = MockAsyncRetriever::new();
        retriever
            .expect_retrieve()
            .with(eq(USER_MESSAGE), eq(top_k()))
            .returning(|_, _| Ok(vec![Chunk::new(RAG_CHUNK_1), Chunk::new(RAG_CHUNK_2)]));
        retriever
    }

    #[tokio::test]
    async fn test_chain_succeeds() {
        let mut chat_client = MockAsyncChatClient::new();
        chat_client
            .expect_invoke()
            .with(eq(expected_prompts()))
            .returning(|_| Ok(mocked_response()));

        let chain = BasicRAGChain::<MockAsyncChatClient, MockAsyncRetriever>::builder()
            .system_prompt(system_prompt())
            .chat_client(chat_client)
            .retriever(retriever_with_two_chunks())
            .build();

        let reply = chain
            .invoke_chain(PromptMessage::HumanMessage(USER_MESSAGE.into()), top_k())
            .await
            .unwrap();

        assert_eq!(mocked_response(), reply)
    }

    #[tokio::test]
    async fn test_streamed_chain_succeeds() {
        let mut chat_client = MockAsyncStreamedChatClient::new();
        chat_client
            .expect_invoke_stream()
            .with(eq(expected_prompts()))
            .returning(move |_| {
                // Every poll of the mocked stream yields the same reply.
                let mut stream = MockChatCompletionStream::new();
                stream
                    .expect_next()
                    .returning(|| Some(Ok(mocked_response())));
                Ok(stream)
            });

        let chain =
            BasicStreamedRAGChain::<MockAsyncStreamedChatClient, MockAsyncRetriever>::builder()
                .system_prompt(system_prompt())
                .chat_client(chat_client)
                .retriever(retriever_with_two_chunks())
                .build();

        let mut stream = chain
            .invoke_chain(PromptMessage::HumanMessage(USER_MESSAGE.into()), top_k())
            .await
            .unwrap();

        assert_eq!(stream.next().await.unwrap().unwrap(), mocked_response());
    }
}