// rag_toolchain/chains/basic_rag_chain.rs

1use crate::{
2    chains::{utils::build_prompt, RagChainError},
3    clients::{AsyncChatClient, AsyncStreamedChatClient, PromptMessage},
4    common::Chunks,
5    retrievers::AsyncRetriever,
6};
7use std::num::NonZeroU32;
8use typed_builder::TypedBuilder;
9
10/// # [`BasicRAGChain`]
11///
12/// This struct allows for easily executing RAG given a single user prompt.
13/// the current implementation relies on async chat clients and async retrievers.
14/// we use generics in order to preserve error types via associated types.
15///
16/// * `T` - The type of the chat client to be used
17/// * `U` - The type of the retriever to be used
18///
19/// # Examples
20/// ```
21/// use rag_toolchain::clients::*;
22/// use rag_toolchain::retrievers::*;
23/// use rag_toolchain::chains::*;
24/// use rag_toolchain::stores::*;
25/// use rag_toolchain::common::*;
26/// use std::num::NonZeroU32;
27///
28/// async fn run_chain() {
29///
30///    const SYSTEM_MESSAGE: &'static str =
31///         "You are to give straight forward answers using the supporting information you are provided";
32///
33///    let store: PostgresVectorStore =
34///    PostgresVectorStore::try_new("embeddings", OpenAIEmbeddingModel::TextEmbeddingAda002)
35///        .await
36///        .unwrap();
37///
38///    let embedding_client: OpenAIEmbeddingClient =
39///        OpenAIEmbeddingClient::try_new(OpenAIEmbeddingModel::TextEmbeddingAda002).unwrap();
40///
41///    let retriever: PostgresVectorRetriever<OpenAIEmbeddingClient> =
42///        store.as_retriever(embedding_client, DistanceFunction::Cosine);
43///
44///    let chat_client: OpenAIChatCompletionClient =
45///        OpenAIChatCompletionClient::try_new(OpenAIModel::Gpt3Point5Turbo).unwrap();
46///
47///    let system_prompt: PromptMessage = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());
48///
49///    let chain: BasicRAGChain<OpenAIChatCompletionClient, PostgresVectorRetriever<_>> =
50///        BasicRAGChain::builder()
51///            .system_prompt(system_prompt)
52///            .chat_client(chat_client)
53///            .retriever(retriever)
54///            .build();
55///    let user_message: PromptMessage =
56///        PromptMessage::HumanMessage("what kind of alcohol does Morwenna drink".into());
57///
58///    let response = chain
59///        .invoke_chain(user_message, NonZeroU32::new(2).unwrap())
60///        .await
61///        .unwrap();
62/// }
63/// ```
64#[derive(Debug, TypedBuilder, Clone, PartialEq, Eq)]
65pub struct BasicRAGChain<T, U>
66where
67    T: AsyncChatClient,
68    U: AsyncRetriever,
69{
70    #[builder(default, setter(strip_option))]
71    system_prompt: Option<PromptMessage>,
72    chat_client: T,
73    retriever: U,
74}
75
76impl<T, U> BasicRAGChain<T, U>
77where
78    T: AsyncChatClient,
79    U: AsyncRetriever,
80{
81    /// # [`BasicRAGChain::invoke_chain`]
82    ///
83    /// function to execute the RAG chain given a user prompt and a top_k value.
84    /// we take the supplied user prompt and retrieve supporting chunks from the retriever.
85    /// those chunks are then used to build a new prompt which is then sent to the chat client.
86    /// the new prompt then becomes:
87    ///
88    /// user prompt
89    ///
90    /// Here is some supporting information:
91    ///
92    /// chunk1
93    ///
94    /// chunk2
95    ///
96    /// ...
97    ///
98    /// # Arguments
99    /// * `user_message`: [`PromptMessage`] - the user prompt, this will be used to retrieve supporting chunks
100    /// * `top_k`: [`NonZeroU32`] - the number of supporting chunks to retrieve
101    ///
102    /// # Errors
103    /// * [`RagChainError`] - if the chat client or retriever fails.
104    ///
105    /// # Returns
106    /// [`PromptMessage`] - the response from the chat client
107    pub async fn invoke_chain(
108        &self,
109        user_message: PromptMessage,
110        top_k: NonZeroU32,
111    ) -> Result<PromptMessage, RagChainError<T::ErrorType, U::ErrorType>> {
112        let content = user_message.content();
113        let chunks: Chunks = self
114            .retriever
115            .retrieve(content, top_k)
116            .await
117            .map_err(RagChainError::RetrieverError::<T::ErrorType, U::ErrorType>)?;
118
119        let new_prompt: PromptMessage = build_prompt(&user_message, chunks);
120
121        let prompts = match self.system_prompt.clone() {
122            None => vec![new_prompt],
123            Some(prompt) => vec![prompt, new_prompt],
124        };
125
126        let result = self
127            .chat_client
128            .invoke(prompts)
129            .await
130            .map_err(RagChainError::ChatClientError::<T::ErrorType, U::ErrorType>)?;
131
132        Ok(result)
133    }
134}
135
136/// # [`BasicStreamedRAGChain`]
137///
138/// This struct allows for easily executing RAG given a single user prompt.
139/// the current implementation relies on async streamed chat clients and async retrievers.
140/// we use generics in order to preserve error types via associated types.
141///
142/// * `T` - The type of the streamed chat client to be used
143/// * `U` - The type of the retriever to be used
144///
145/// # Examples
146/// ```
147/// use rag_toolchain::clients::*;
148/// use rag_toolchain::retrievers::*;
149/// use rag_toolchain::chains::*;
150/// use rag_toolchain::stores::*;
151/// use rag_toolchain::common::*;
152/// use std::num::NonZeroU32;
153///
154/// async fn run_chain() {
155///
156///    const SYSTEM_MESSAGE: &'static str =
157///         "You are to give straight forward answers using the supporting information you are provided";
158///
159///    let store: PostgresVectorStore =
160///    PostgresVectorStore::try_new("embeddings", OpenAIEmbeddingModel::TextEmbeddingAda002)
161///        .await
162///        .unwrap();
163///
164///    let embedding_client: OpenAIEmbeddingClient =
165///        OpenAIEmbeddingClient::try_new(OpenAIEmbeddingModel::TextEmbeddingAda002).unwrap();
166///
167///    let retriever: PostgresVectorRetriever<OpenAIEmbeddingClient> =
168///        store.as_retriever(embedding_client, DistanceFunction::Cosine);
169///
170///    let chat_client: OpenAIChatCompletionClient =
171///        OpenAIChatCompletionClient::try_new(OpenAIModel::Gpt3Point5Turbo).unwrap();
172///
173///    let system_prompt: PromptMessage = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());
174///
175///    let chain: BasicStreamedRAGChain<OpenAIChatCompletionClient, PostgresVectorRetriever<_>> =
176///        BasicStreamedRAGChain::builder()
177///            .system_prompt(system_prompt)
178///            .chat_client(chat_client)
179///            .retriever(retriever)
180///            .build();
181///    let user_message: PromptMessage =
182///        PromptMessage::HumanMessage("what kind of alcohol does Morwenna drink".into());
183///
184///    let stream = chain
185///        .invoke_chain(user_message, NonZeroU32::new(2).unwrap())
186///        .await
187///        .unwrap();
188/// }
189/// ```
190#[derive(Debug, TypedBuilder, Clone, PartialEq, Eq)]
191pub struct BasicStreamedRAGChain<T, U>
192where
193    T: AsyncStreamedChatClient,
194    U: AsyncRetriever,
195{
196    #[builder(default, setter(strip_option))]
197    system_prompt: Option<PromptMessage>,
198    chat_client: T,
199    retriever: U,
200}
201
202impl<T, U> BasicStreamedRAGChain<T, U>
203where
204    T: AsyncStreamedChatClient,
205    U: AsyncRetriever,
206{
207    pub async fn invoke_chain(
208        &self,
209        user_message: PromptMessage,
210        top_k: NonZeroU32,
211    ) -> Result<T::Item, RagChainError<T::ErrorType, U::ErrorType>> {
212        let content = user_message.content();
213        let chunks: Chunks = self
214            .retriever
215            .retrieve(content, top_k)
216            .await
217            .map_err(RagChainError::RetrieverError::<T::ErrorType, U::ErrorType>)?;
218
219        let new_prompt: PromptMessage = build_prompt(&user_message, chunks);
220
221        let prompts = match self.system_prompt.clone() {
222            None => vec![new_prompt],
223            Some(prompt) => vec![prompt, new_prompt],
224        };
225
226        let result = self
227            .chat_client
228            .invoke_stream(prompts)
229            .await
230            .map_err(RagChainError::ChatClientError::<T::ErrorType, U::ErrorType>)?;
231
232        Ok(result)
233    }
234}
235
#[cfg(test)]
mod basic_rag_chain_tests {
    use super::*;
    use crate::{
        clients::{
            ChatCompletionStream, MockAsyncChatClient, MockAsyncStreamedChatClient,
            MockChatCompletionStream,
        },
        common::Chunk,
        retrievers::MockAsyncRetriever,
    };
    use mockall::predicate::eq;
    use std::vec;

    const SYSTEM_MESSAGE: &str = "you are a study buddy";
    const USER_MESSAGE: &str = "please tell me about my lecture on operating systems";
    const RAG_CHUNK_1: &str = "data point 1";
    const RAG_CHUNK_2: &str = "data point 2";

    // Builds the augmented prompt text the chain is expected to hand the chat client.
    fn expected_augmented_message() -> String {
        format!(
            "{}\n{}\n{}\n{}\n",
            USER_MESSAGE, "Here is some supporting information:", RAG_CHUNK_1, RAG_CHUNK_2
        )
    }

    // Builds a retriever mock that returns the two canned chunks for the user message.
    fn stubbed_retriever() -> MockAsyncRetriever {
        let mut mock_retriever = MockAsyncRetriever::new();
        mock_retriever
            .expect_retrieve()
            .with(eq(USER_MESSAGE), eq(NonZeroU32::new(2).unwrap()))
            .returning(|_, _| Ok(vec![Chunk::new(RAG_CHUNK_1), Chunk::new(RAG_CHUNK_2)]));
        mock_retriever
    }

    #[tokio::test]
    async fn test_chain_succeeds() {
        let system_prompt = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());
        let mock_retriever = stubbed_retriever();

        // The chat client must receive the system prompt followed by the augmented prompt.
        let mut mock_chat_client = MockAsyncChatClient::new();
        mock_chat_client
            .expect_invoke()
            .with(eq(vec![
                system_prompt.clone(),
                PromptMessage::HumanMessage(expected_augmented_message().into()),
            ]))
            .returning(|_| Ok(PromptMessage::AIMessage("mocked response".into())));

        let chain: BasicRAGChain<MockAsyncChatClient, MockAsyncRetriever> =
            BasicRAGChain::builder()
                .system_prompt(system_prompt)
                .chat_client(mock_chat_client)
                .retriever(mock_retriever)
                .build();

        let response = chain
            .invoke_chain(
                PromptMessage::HumanMessage(USER_MESSAGE.into()),
                NonZeroU32::new(2).unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(PromptMessage::AIMessage("mocked response".into()), response)
    }

    #[tokio::test]
    async fn test_streamed_chain_succeeds() {
        let system_prompt = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());
        let mock_retriever = stubbed_retriever();

        // The streamed client must receive the same prompts and hand back a stream
        // whose next() yields the mocked AI message.
        let mut mock_chat_client = MockAsyncStreamedChatClient::new();
        mock_chat_client
            .expect_invoke_stream()
            .with(eq(vec![
                system_prompt.clone(),
                PromptMessage::HumanMessage(expected_augmented_message().into()),
            ]))
            .returning(move |_| {
                let mut mock_stream = MockChatCompletionStream::new();
                mock_stream
                    .expect_next()
                    .returning(|| Some(Ok(PromptMessage::AIMessage("mocked response".into()))));
                Ok(mock_stream)
            });

        let chain: BasicStreamedRAGChain<MockAsyncStreamedChatClient, MockAsyncRetriever> =
            BasicStreamedRAGChain::builder()
                .system_prompt(system_prompt)
                .chat_client(mock_chat_client)
                .retriever(mock_retriever)
                .build();

        let mut stream = chain
            .invoke_chain(
                PromptMessage::HumanMessage(USER_MESSAGE.into()),
                NonZeroU32::new(2).unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            PromptMessage::AIMessage("mocked response".into())
        );
    }
}