//! rag_toolchain/chains/basic_rag_chain.rs
1use crate::{
2 chains::{utils::build_prompt, RagChainError},
3 clients::{AsyncChatClient, AsyncStreamedChatClient, PromptMessage},
4 common::Chunks,
5 retrievers::AsyncRetriever,
6};
7use std::num::NonZeroU32;
8use typed_builder::TypedBuilder;
9
/// # [`BasicRAGChain`]
///
/// This struct allows for easily executing RAG given a single user prompt.
/// the current implementation relies on async chat clients and async retrievers.
/// we use generics in order to preserve error types via associated types.
///
/// * `T` - The type of the chat client to be used
/// * `U` - The type of the retriever to be used
///
/// # Examples
/// ```
/// use rag_toolchain::clients::*;
/// use rag_toolchain::retrievers::*;
/// use rag_toolchain::chains::*;
/// use rag_toolchain::stores::*;
/// use rag_toolchain::common::*;
/// use std::num::NonZeroU32;
///
/// async fn run_chain() {
///     const SYSTEM_MESSAGE: &'static str =
///         "You are to give straight forward answers using the supporting information you are provided";
///
///     let store: PostgresVectorStore =
///         PostgresVectorStore::try_new("embeddings", OpenAIEmbeddingModel::TextEmbeddingAda002)
///             .await
///             .unwrap();
///
///     let embedding_client: OpenAIEmbeddingClient =
///         OpenAIEmbeddingClient::try_new(OpenAIEmbeddingModel::TextEmbeddingAda002).unwrap();
///
///     let retriever: PostgresVectorRetriever<OpenAIEmbeddingClient> =
///         store.as_retriever(embedding_client, DistanceFunction::Cosine);
///
///     let chat_client: OpenAIChatCompletionClient =
///         OpenAIChatCompletionClient::try_new(OpenAIModel::Gpt3Point5Turbo).unwrap();
///
///     let system_prompt: PromptMessage = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());
///
///     let chain: BasicRAGChain<OpenAIChatCompletionClient, PostgresVectorRetriever<_>> =
///         BasicRAGChain::builder()
///             .system_prompt(system_prompt)
///             .chat_client(chat_client)
///             .retriever(retriever)
///             .build();
///
///     let user_message: PromptMessage =
///         PromptMessage::HumanMessage("what kind of alcohol does Morwenna drink".into());
///
///     let response = chain
///         .invoke_chain(user_message, NonZeroU32::new(2).unwrap())
///         .await
///         .unwrap();
/// }
/// ```
#[derive(Debug, TypedBuilder, Clone, PartialEq, Eq)]
pub struct BasicRAGChain<T, U>
where
    T: AsyncChatClient,
    U: AsyncRetriever,
{
    /// Optional system prompt sent ahead of the augmented user prompt.
    /// The builder setter strips the `Option`; when the setter is not
    /// called this defaults to `None` and no system message is sent.
    #[builder(default, setter(strip_option))]
    system_prompt: Option<PromptMessage>,
    /// Chat client used to generate the final response.
    chat_client: T,
    /// Retriever used to fetch supporting chunks for the user prompt.
    retriever: U,
}
75
76impl<T, U> BasicRAGChain<T, U>
77where
78 T: AsyncChatClient,
79 U: AsyncRetriever,
80{
81 /// # [`BasicRAGChain::invoke_chain`]
82 ///
83 /// function to execute the RAG chain given a user prompt and a top_k value.
84 /// we take the supplied user prompt and retrieve supporting chunks from the retriever.
85 /// those chunks are then used to build a new prompt which is then sent to the chat client.
86 /// the new prompt then becomes:
87 ///
88 /// user prompt
89 ///
90 /// Here is some supporting information:
91 ///
92 /// chunk1
93 ///
94 /// chunk2
95 ///
96 /// ...
97 ///
98 /// # Arguments
99 /// * `user_message`: [`PromptMessage`] - the user prompt, this will be used to retrieve supporting chunks
100 /// * `top_k`: [`NonZeroU32`] - the number of supporting chunks to retrieve
101 ///
102 /// # Errors
103 /// * [`RagChainError`] - if the chat client or retriever fails.
104 ///
105 /// # Returns
106 /// [`PromptMessage`] - the response from the chat client
107 pub async fn invoke_chain(
108 &self,
109 user_message: PromptMessage,
110 top_k: NonZeroU32,
111 ) -> Result<PromptMessage, RagChainError<T::ErrorType, U::ErrorType>> {
112 let content = user_message.content();
113 let chunks: Chunks = self
114 .retriever
115 .retrieve(content, top_k)
116 .await
117 .map_err(RagChainError::RetrieverError::<T::ErrorType, U::ErrorType>)?;
118
119 let new_prompt: PromptMessage = build_prompt(&user_message, chunks);
120
121 let prompts = match self.system_prompt.clone() {
122 None => vec![new_prompt],
123 Some(prompt) => vec![prompt, new_prompt],
124 };
125
126 let result = self
127 .chat_client
128 .invoke(prompts)
129 .await
130 .map_err(RagChainError::ChatClientError::<T::ErrorType, U::ErrorType>)?;
131
132 Ok(result)
133 }
134}
135
/// # [`BasicStreamedRAGChain`]
///
/// This struct allows for easily executing RAG given a single user prompt.
/// the current implementation relies on async streamed chat clients and async retrievers.
/// we use generics in order to preserve error types via associated types.
///
/// * `T` - The type of the streamed chat client to be used
/// * `U` - The type of the retriever to be used
///
/// # Examples
/// ```
/// use rag_toolchain::clients::*;
/// use rag_toolchain::retrievers::*;
/// use rag_toolchain::chains::*;
/// use rag_toolchain::stores::*;
/// use rag_toolchain::common::*;
/// use std::num::NonZeroU32;
///
/// async fn run_chain() {
///     const SYSTEM_MESSAGE: &'static str =
///         "You are to give straight forward answers using the supporting information you are provided";
///
///     let store: PostgresVectorStore =
///         PostgresVectorStore::try_new("embeddings", OpenAIEmbeddingModel::TextEmbeddingAda002)
///             .await
///             .unwrap();
///
///     let embedding_client: OpenAIEmbeddingClient =
///         OpenAIEmbeddingClient::try_new(OpenAIEmbeddingModel::TextEmbeddingAda002).unwrap();
///
///     let retriever: PostgresVectorRetriever<OpenAIEmbeddingClient> =
///         store.as_retriever(embedding_client, DistanceFunction::Cosine);
///
///     let chat_client: OpenAIChatCompletionClient =
///         OpenAIChatCompletionClient::try_new(OpenAIModel::Gpt3Point5Turbo).unwrap();
///
///     let system_prompt: PromptMessage = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());
///
///     let chain: BasicStreamedRAGChain<OpenAIChatCompletionClient, PostgresVectorRetriever<_>> =
///         BasicStreamedRAGChain::builder()
///             .system_prompt(system_prompt)
///             .chat_client(chat_client)
///             .retriever(retriever)
///             .build();
///
///     let user_message: PromptMessage =
///         PromptMessage::HumanMessage("what kind of alcohol does Morwenna drink".into());
///
///     let stream = chain
///         .invoke_chain(user_message, NonZeroU32::new(2).unwrap())
///         .await
///         .unwrap();
/// }
/// ```
#[derive(Debug, TypedBuilder, Clone, PartialEq, Eq)]
pub struct BasicStreamedRAGChain<T, U>
where
    T: AsyncStreamedChatClient,
    U: AsyncRetriever,
{
    /// Optional system prompt sent ahead of the augmented user prompt.
    /// The builder setter strips the `Option`; when the setter is not
    /// called this defaults to `None` and no system message is sent.
    #[builder(default, setter(strip_option))]
    system_prompt: Option<PromptMessage>,
    /// Streamed chat client used to generate the response stream.
    chat_client: T,
    /// Retriever used to fetch supporting chunks for the user prompt.
    retriever: U,
}
201
202impl<T, U> BasicStreamedRAGChain<T, U>
203where
204 T: AsyncStreamedChatClient,
205 U: AsyncRetriever,
206{
207 pub async fn invoke_chain(
208 &self,
209 user_message: PromptMessage,
210 top_k: NonZeroU32,
211 ) -> Result<T::Item, RagChainError<T::ErrorType, U::ErrorType>> {
212 let content = user_message.content();
213 let chunks: Chunks = self
214 .retriever
215 .retrieve(content, top_k)
216 .await
217 .map_err(RagChainError::RetrieverError::<T::ErrorType, U::ErrorType>)?;
218
219 let new_prompt: PromptMessage = build_prompt(&user_message, chunks);
220
221 let prompts = match self.system_prompt.clone() {
222 None => vec![new_prompt],
223 Some(prompt) => vec![prompt, new_prompt],
224 };
225
226 let result = self
227 .chat_client
228 .invoke_stream(prompts)
229 .await
230 .map_err(RagChainError::ChatClientError::<T::ErrorType, U::ErrorType>)?;
231
232 Ok(result)
233 }
234}
235
#[cfg(test)]
mod basic_rag_chain_tests {
    use super::*;
    use crate::{
        clients::{
            ChatCompletionStream, MockAsyncChatClient, MockAsyncStreamedChatClient,
            MockChatCompletionStream,
        },
        common::Chunk,
        retrievers::MockAsyncRetriever,
    };
    use mockall::predicate::eq;
    use std::vec;

    /// The non-streamed chain should query the retriever with the raw user
    /// message, send the chunk-augmented prompt (after the system prompt) to
    /// the chat client, and return the client's reply unchanged.
    #[tokio::test]
    async fn test_chain_succeeds() {
        const SYSTEM_MESSAGE: &str = "you are a study buddy";
        const USER_MESSAGE: &str = "please tell me about my lecture on operating systems";
        const RAG_CHUNK_1: &str = "data point 1";
        const RAG_CHUNK_2: &str = "data point 2";

        // Prompt the chain is expected to build: user message, header line,
        // then each retrieved chunk, each terminated by a newline.
        let parts = [
            USER_MESSAGE,
            "Here is some supporting information:",
            RAG_CHUNK_1,
            RAG_CHUNK_2,
        ];
        let mut augmented_message = String::new();
        for part in parts.iter() {
            augmented_message.push_str(part);
            augmented_message.push('\n');
        }

        let system_prompt = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());

        let mut mock_retriever = MockAsyncRetriever::new();
        mock_retriever
            .expect_retrieve()
            .with(eq(USER_MESSAGE), eq(NonZeroU32::new(2).unwrap()))
            .returning(|_, _| Ok(vec![Chunk::new(RAG_CHUNK_1), Chunk::new(RAG_CHUNK_2)]));

        let mut mock_chat_client = MockAsyncChatClient::new();
        mock_chat_client
            .expect_invoke()
            .with(eq(vec![
                system_prompt.clone(),
                PromptMessage::HumanMessage(augmented_message.into()),
            ]))
            .returning(|_| Ok(PromptMessage::AIMessage("mocked response".into())));

        let chain: BasicRAGChain<MockAsyncChatClient, MockAsyncRetriever> =
            BasicRAGChain::builder()
                .system_prompt(system_prompt)
                .chat_client(mock_chat_client)
                .retriever(mock_retriever)
                .build();

        let response = chain
            .invoke_chain(
                PromptMessage::HumanMessage(USER_MESSAGE.into()),
                NonZeroU32::new(2).unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response, PromptMessage::AIMessage("mocked response".into()))
    }

    /// The streamed chain should behave like the non-streamed one but return
    /// a stream; pulling one item off it yields the mocked AI message.
    #[tokio::test]
    async fn test_streamed_chain_succeeds() {
        const SYSTEM_MESSAGE: &str = "you are a study buddy";
        const USER_MESSAGE: &str = "please tell me about my lecture on operating systems";
        const RAG_CHUNK_1: &str = "data point 1";
        const RAG_CHUNK_2: &str = "data point 2";

        // Same augmented-prompt shape as the non-streamed test above.
        let parts = [
            USER_MESSAGE,
            "Here is some supporting information:",
            RAG_CHUNK_1,
            RAG_CHUNK_2,
        ];
        let mut augmented_message = String::new();
        for part in parts.iter() {
            augmented_message.push_str(part);
            augmented_message.push('\n');
        }

        let system_prompt = PromptMessage::SystemMessage(SYSTEM_MESSAGE.into());

        let mut mock_retriever = MockAsyncRetriever::new();
        mock_retriever
            .expect_retrieve()
            .with(eq(USER_MESSAGE), eq(NonZeroU32::new(2).unwrap()))
            .returning(|_, _| Ok(vec![Chunk::new(RAG_CHUNK_1), Chunk::new(RAG_CHUNK_2)]));

        let mut mock_chat_client = MockAsyncStreamedChatClient::new();
        mock_chat_client
            .expect_invoke_stream()
            .with(eq(vec![
                system_prompt.clone(),
                PromptMessage::HumanMessage(augmented_message.into()),
            ]))
            .returning(move |_| {
                // Hand back a fresh mock stream that yields one AI message.
                let mut stream = MockChatCompletionStream::new();
                stream
                    .expect_next()
                    .returning(|| Some(Ok(PromptMessage::AIMessage("mocked response".into()))));
                Ok(stream)
            });

        let chain: BasicStreamedRAGChain<MockAsyncStreamedChatClient, MockAsyncRetriever> =
            BasicStreamedRAGChain::builder()
                .system_prompt(system_prompt)
                .chat_client(mock_chat_client)
                .retriever(mock_retriever)
                .build();

        let mut stream = chain
            .invoke_chain(
                PromptMessage::HumanMessage(USER_MESSAGE.into()),
                NonZeroU32::new(2).unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            PromptMessage::AIMessage("mocked response".into())
        );
    }
}
347}