// rig/agent/completion.rs

1use super::prompt_request::{self, PromptRequest};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError,
7    },
8    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
9    tool::ToolSet,
10    vector_store::{VectorStoreError, request::VectorSearchRequest},
11};
12use futures::{StreamExt, TryStreamExt, stream};
13use std::{collections::HashMap, sync::Arc};
14
/// Fallback display name used (in logs and tracing spans) when an agent was
/// built without an explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
16
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Name of the agent used for logging and debugging
    pub name: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    pub static_tools: Vec<String>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Vector store indices queried at prompt time for extra context documents;
    /// each entry pairs the number of samples to retrieve with the index to query.
    pub dynamic_context: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
    /// Vector store indices queried at prompt time to select additional tools;
    /// each entry pairs the number of samples to retrieve with the index to query.
    pub dynamic_tools: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
    /// Actual tool implementations
    pub tools: Arc<ToolSet>,
}
66
67impl<M> Agent<M>
68where
69    M: CompletionModel,
70{
71    /// Returns the name of the agent.
72    pub(crate) fn name(&self) -> &str {
73        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
74    }
75
76    /// Returns the name of the agent as an owned variable.
77    /// Useful in some cases where having the agent name as an owned variable is required.
78    pub(crate) fn name_owned(&self) -> String {
79        self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string())
80    }
81}
82
impl<M> Completion<M> for Agent<M>
where
    M: CompletionModel,
{
    /// Builds (but does not send) a completion request for `prompt` against this
    /// agent's model, merging in the agent's settings, preamble, static context
    /// documents and tools, and — when RAG text is available — context documents
    /// and tools retrieved from the dynamic vector store indices.
    ///
    /// Returns the request builder so callers can further customize or send it.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] when a vector store lookup fails.
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();

        // Find the latest message in the chat history that contains RAG text.
        // The prompt itself takes precedence; otherwise scan history newest-first.
        let rag_text = prompt.rag_text();
        let rag_text = rag_text.or_else(|| {
            chat_history
                .iter()
                .rev()
                .find_map(|message| message.rag_text())
        });

        // Base request: model settings plus the agent's always-available context docs.
        let completion_request = self
            .model
            .completion_request(prompt)
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());
        // The preamble (system prompt) is optional; only set it when present.
        let completion_request = if let Some(preamble) = &self.preamble {
            completion_request.preamble(preamble.to_owned())
        } else {
            completion_request
        };

        // If the agent has RAG text, we need to fetch the dynamic context and tools
        let agent = match &rag_text {
            Some(text) => {
                // Query each dynamic-context index for its configured sample count
                // and flatten every hit into a single list of `Document`s.
                let dynamic_context = stream::iter(self.dynamic_context.iter())
                    .then(|(num_sample, index)| async {
                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(req)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    // Accumulate per-index result batches into one Vec, short-circuiting
                    // on the first vector store error.
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                // Query each dynamic-tool index for tool names, then resolve each
                // name to its definition via the toolset (unknown names are logged
                // and skipped rather than failing the request).
                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
                    .then(|(num_sample, index)| async {
                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n_ids(req)
                                .await?
                                .into_iter()
                                .map(|(_, id)| id)
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        for doc in docs {
                            if let Some(tool) = self.tools.get(&doc) {
                                acc.push(tool.definition(text.into()).await)
                            } else {
                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
                            }
                        }
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                // Resolve the always-on tools, passing the RAG text so tool
                // definitions can tailor themselves to the query.
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            Some(tool.definition(text.into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                // Attach the retrieved documents and the combined (static + dynamic)
                // tool definitions to the request.
                completion_request
                    .documents(dynamic_context)
                    .tools([static_tools.clone(), dynamic_tools].concat())
            }
            None => {
                // No RAG text: only the static tools are attached, with an empty query.
                let static_tools = stream::iter(self.static_tools.iter())
                    .filter_map(|toolname| async move {
                        if let Some(tool) = self.tools.get(toolname) {
                            // TODO: tool definitions should likely take an `Option<String>`
                            Some(tool.definition("".into()).await)
                        } else {
                            tracing::warn!(
                                "Tool implementation not found in toolset: {}",
                                toolname
                            );
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .await;

                completion_request.tools(static_tools)
            }
        };

        Ok(agent)
    }
}
218
219// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
220//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
221//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
222//
223// References:
224//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
225
#[allow(refining_impl_trait)]
impl<M> Prompt for Agent<M>
where
    M: CompletionModel,
{
    /// Creates a [`PromptRequest`] builder for this agent. The refined return
    /// type (rather than the trait's opaque one) keeps the builder's methods
    /// available fluently at the call-site.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(self, prompt)
    }
}
239
#[allow(refining_impl_trait)]
impl<M> Prompt for &Agent<M>
where
    M: CompletionModel,
{
    /// Same as the `Prompt` impl on `Agent<M>`, but for a borrowed agent; the
    /// returned request borrows through the reference (`*self`).
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(*self, prompt)
    }
}
253
254#[allow(refining_impl_trait)]
255impl<M> Chat for Agent<M>
256where
257    M: CompletionModel,
258{
259    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
260    async fn chat(
261        &self,
262        prompt: impl Into<Message> + Send,
263        mut chat_history: Vec<Message>,
264    ) -> Result<String, PromptError> {
265        PromptRequest::new(self, prompt)
266            .with_history(&mut chat_history)
267            .await
268    }
269}
270
271impl<M> StreamingCompletion<M> for Agent<M>
272where
273    M: CompletionModel,
274{
275    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
276    async fn stream_completion(
277        &self,
278        prompt: impl Into<Message> + Send,
279        chat_history: Vec<Message>,
280    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
281        // Reuse the existing completion implementation to build the request
282        // This ensures streaming and non-streaming use the same request building logic
283        self.completion(prompt, chat_history).await
284    }
285}
286
287impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
288where
289    M: CompletionModel + 'static,
290    M::StreamingResponse: GetTokenUsage,
291{
292    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
293    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()> {
294        let arc = Arc::new(self.clone());
295        StreamingPromptRequest::new(arc, prompt)
296    }
297}
298
299impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
300where
301    M: CompletionModel + 'static,
302    M::StreamingResponse: GetTokenUsage,
303{
304    fn stream_chat(
305        &self,
306        prompt: impl Into<Message> + Send,
307        chat_history: Vec<Message>,
308    ) -> StreamingPromptRequest<M, ()> {
309        let arc = Arc::new(self.clone());
310        StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
311    }
312}