// rig/agent/completion.rs

1use super::prompt_request::{self, PromptRequest};
2use crate::{
3    completion::{
4        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
5        Message, Prompt, PromptError,
6    },
7    streaming::{StreamingChat, StreamingCompletion, StreamingCompletionResponse, StreamingPrompt},
8    tool::ToolSet,
9    vector_store::{VectorStoreError, request::VectorSearchRequest},
10};
11use futures::{StreamExt, TryStreamExt, stream};
12use std::collections::HashMap;
13
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: M,
    /// System prompt sent as the preamble of every completion request
    pub preamble: String,
    /// Context documents always attached to every completion request
    pub static_context: Vec<Document>,
    /// Names of tools that are always available to the agent; each name is
    /// resolved against `tools` when the request is built
    pub static_tools: Vec<String>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Vector store indices queried at prompt time for extra context
    /// documents, each paired with the number of samples to retrieve
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Vector store indices queried at prompt time for tool names, each
    /// paired with the number of samples to retrieve
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Actual tool implementations (looked up by name for both static and
    /// dynamic tools)
    pub tools: ToolSet,
}
56
57impl<M: CompletionModel> Completion<M> for Agent<M> {
58    async fn completion(
59        &self,
60        prompt: impl Into<Message> + Send,
61        chat_history: Vec<Message>,
62    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
63        let prompt = prompt.into();
64
65        // Find the latest message in the chat history that contains RAG text
66        let rag_text = prompt.rag_text();
67        let rag_text = rag_text.or_else(|| {
68            chat_history
69                .iter()
70                .rev()
71                .find_map(|message| message.rag_text())
72        });
73
74        let completion_request = self
75            .model
76            .completion_request(prompt)
77            .preamble(self.preamble.clone())
78            .messages(chat_history)
79            .temperature_opt(self.temperature)
80            .max_tokens_opt(self.max_tokens)
81            .additional_params_opt(self.additional_params.clone())
82            .documents(self.static_context.clone());
83
84        // If the agent has RAG text, we need to fetch the dynamic context and tools
85        let agent = match &rag_text {
86            Some(text) => {
87                let dynamic_context = stream::iter(self.dynamic_context.iter())
88                    .then(|(num_sample, index)| async {
89                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
90                        Ok::<_, VectorStoreError>(
91                            index
92                                .top_n(req)
93                                .await?
94                                .into_iter()
95                                .map(|(_, id, doc)| {
96                                    // Pretty print the document if possible for better readability
97                                    let text = serde_json::to_string_pretty(&doc)
98                                        .unwrap_or_else(|_| doc.to_string());
99
100                                    Document {
101                                        id,
102                                        text,
103                                        additional_props: HashMap::new(),
104                                    }
105                                })
106                                .collect::<Vec<_>>(),
107                        )
108                    })
109                    .try_fold(vec![], |mut acc, docs| async {
110                        acc.extend(docs);
111                        Ok(acc)
112                    })
113                    .await
114                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
115
116                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
117                    .then(|(num_sample, index)| async {
118                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
119                        Ok::<_, VectorStoreError>(
120                            index
121                                .top_n_ids(req)
122                                .await?
123                                .into_iter()
124                                .map(|(_, id)| id)
125                                .collect::<Vec<_>>(),
126                        )
127                    })
128                    .try_fold(vec![], |mut acc, docs| async {
129                        for doc in docs {
130                            if let Some(tool) = self.tools.get(&doc) {
131                                acc.push(tool.definition(text.into()).await)
132                            } else {
133                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
134                            }
135                        }
136                        Ok(acc)
137                    })
138                    .await
139                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
140
141                let static_tools = stream::iter(self.static_tools.iter())
142                    .filter_map(|toolname| async move {
143                        if let Some(tool) = self.tools.get(toolname) {
144                            Some(tool.definition(text.into()).await)
145                        } else {
146                            tracing::warn!(
147                                "Tool implementation not found in toolset: {}",
148                                toolname
149                            );
150                            None
151                        }
152                    })
153                    .collect::<Vec<_>>()
154                    .await;
155
156                completion_request
157                    .documents(dynamic_context)
158                    .tools([static_tools.clone(), dynamic_tools].concat())
159            }
160            None => {
161                let static_tools = stream::iter(self.static_tools.iter())
162                    .filter_map(|toolname| async move {
163                        if let Some(tool) = self.tools.get(toolname) {
164                            // TODO: tool definitions should likely take an `Option<String>`
165                            Some(tool.definition("".into()).await)
166                        } else {
167                            tracing::warn!(
168                                "Tool implementation not found in toolset: {}",
169                                toolname
170                            );
171                            None
172                        }
173                    })
174                    .collect::<Vec<_>>()
175                    .await;
176
177                completion_request.tools(static_tools)
178            }
179        };
180
181        Ok(agent)
182    }
183}
184
185// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
186//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
187//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
188//
189// References:
190//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
191
// `refining_impl_trait` is allowed here: the trait's opaque return type is
// refined to the concrete `PromptRequest` (see references above).
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Begins a prompt request against this agent. The returned
    /// `PromptRequest` builder resolves when awaited.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<prompt_request::Standard, M> {
        PromptRequest::new(self, prompt)
    }
}
201
// Same refinement for `&Agent` so borrowed agents can be prompted fluently.
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Begins a prompt request against the borrowed agent; the reference is
    /// dereferenced so `PromptRequest` borrows the agent itself.
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<prompt_request::Standard, M> {
        PromptRequest::new(*self, prompt)
    }
}
211
212#[allow(refining_impl_trait)]
213impl<M: CompletionModel> Chat for Agent<M> {
214    async fn chat(
215        &self,
216        prompt: impl Into<Message> + Send,
217        chat_history: Vec<Message>,
218    ) -> Result<String, PromptError> {
219        let mut cloned_history = chat_history.clone();
220        PromptRequest::new(self, prompt)
221            .with_history(&mut cloned_history)
222            .await
223    }
224}
225
impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
    /// Builds a streaming completion request from the prompt and chat
    /// history; delegates to [`Completion::completion`].
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}
237
238impl<M: CompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
239    async fn stream_prompt(
240        &self,
241        prompt: impl Into<Message> + Send,
242    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
243        self.stream_chat(prompt, vec![]).await
244    }
245}
246
247impl<M: CompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
248    async fn stream_chat(
249        &self,
250        prompt: impl Into<Message> + Send,
251        chat_history: Vec<Message>,
252    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
253        self.stream_completion(prompt, chat_history)
254            .await?
255            .stream()
256            .await
257    }
258}