rig/agent/
completion.rs

1use super::prompt_request::{self, PromptRequest};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError,
7    },
8    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
9    tool::ToolSet,
10    vector_store::{VectorStoreError, request::VectorSearchRequest},
11};
12use futures::{StreamExt, TryStreamExt, stream};
13use std::{collections::HashMap, sync::Arc};
14
/// Fallback name reported for an agent whose `name` field is `None`.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
16
17/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
18/// (i.e.: system prompt) and a static set of context documents and tools.
19/// All context documents and tools are always provided to the agent when prompted.
20///
21/// # Example
22/// ```
23/// use rig::{completion::Prompt, providers::openai};
24///
25/// let openai = openai::Client::from_env();
26///
27/// let comedian_agent = openai
28///     .agent("gpt-4o")
29///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
30///     .temperature(0.9)
31///     .build();
32///
33/// let response = comedian_agent.prompt("Entertain me!")
34///     .await
35///     .expect("Failed to prompt the agent");
36/// ```
37#[derive(Clone)]
38#[non_exhaustive]
39pub struct Agent<M>
40where
41    M: CompletionModel,
42{
43    /// Name of the agent used for logging and debugging
44    pub name: Option<String>,
45    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
46    pub model: Arc<M>,
47    /// System prompt
48    pub preamble: String,
49    /// Context documents always available to the agent
50    pub static_context: Vec<Document>,
51    /// Tools that are always available to the agent (identified by their name)
52    pub static_tools: Vec<String>,
53    /// Temperature of the model
54    pub temperature: Option<f64>,
55    /// Maximum number of tokens for the completion
56    pub max_tokens: Option<u64>,
57    /// Additional parameters to be passed to the model
58    pub additional_params: Option<serde_json::Value>,
59    /// List of vector store, with the sample number
60    pub dynamic_context: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
61    /// Dynamic tools
62    pub dynamic_tools: Arc<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>,
63    /// Actual tool implementations
64    pub tools: Arc<ToolSet>,
65}
66
67impl<M> Agent<M>
68where
69    M: CompletionModel,
70{
71    /// Returns the name of the agent.
72    pub(crate) fn name(&self) -> &str {
73        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
74    }
75
76    /// Returns the name of the agent as an owned variable.
77    /// Useful in some cases where having the agent name as an owned variable is required.
78    pub(crate) fn name_owned(&self) -> String {
79        self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string())
80    }
81}
82
83impl<M> Completion<M> for Agent<M>
84where
85    M: CompletionModel,
86{
87    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
88    async fn completion(
89        &self,
90        prompt: impl Into<Message> + Send,
91        chat_history: Vec<Message>,
92    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
93        let prompt = prompt.into();
94
95        // Find the latest message in the chat history that contains RAG text
96        let rag_text = prompt.rag_text();
97        let rag_text = rag_text.or_else(|| {
98            chat_history
99                .iter()
100                .rev()
101                .find_map(|message| message.rag_text())
102        });
103
104        let completion_request = self
105            .model
106            .completion_request(prompt)
107            .preamble(self.preamble.clone())
108            .messages(chat_history)
109            .temperature_opt(self.temperature)
110            .max_tokens_opt(self.max_tokens)
111            .additional_params_opt(self.additional_params.clone())
112            .documents(self.static_context.clone());
113
114        // If the agent has RAG text, we need to fetch the dynamic context and tools
115        let agent = match &rag_text {
116            Some(text) => {
117                let dynamic_context = stream::iter(self.dynamic_context.iter())
118                    .then(|(num_sample, index)| async {
119                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
120                        Ok::<_, VectorStoreError>(
121                            index
122                                .top_n(req)
123                                .await?
124                                .into_iter()
125                                .map(|(_, id, doc)| {
126                                    // Pretty print the document if possible for better readability
127                                    let text = serde_json::to_string_pretty(&doc)
128                                        .unwrap_or_else(|_| doc.to_string());
129
130                                    Document {
131                                        id,
132                                        text,
133                                        additional_props: HashMap::new(),
134                                    }
135                                })
136                                .collect::<Vec<_>>(),
137                        )
138                    })
139                    .try_fold(vec![], |mut acc, docs| async {
140                        acc.extend(docs);
141                        Ok(acc)
142                    })
143                    .await
144                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
145
146                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
147                    .then(|(num_sample, index)| async {
148                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
149                        Ok::<_, VectorStoreError>(
150                            index
151                                .top_n_ids(req)
152                                .await?
153                                .into_iter()
154                                .map(|(_, id)| id)
155                                .collect::<Vec<_>>(),
156                        )
157                    })
158                    .try_fold(vec![], |mut acc, docs| async {
159                        for doc in docs {
160                            if let Some(tool) = self.tools.get(&doc) {
161                                acc.push(tool.definition(text.into()).await)
162                            } else {
163                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
164                            }
165                        }
166                        Ok(acc)
167                    })
168                    .await
169                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
170
171                let static_tools = stream::iter(self.static_tools.iter())
172                    .filter_map(|toolname| async move {
173                        if let Some(tool) = self.tools.get(toolname) {
174                            Some(tool.definition(text.into()).await)
175                        } else {
176                            tracing::warn!(
177                                "Tool implementation not found in toolset: {}",
178                                toolname
179                            );
180                            None
181                        }
182                    })
183                    .collect::<Vec<_>>()
184                    .await;
185
186                completion_request
187                    .documents(dynamic_context)
188                    .tools([static_tools.clone(), dynamic_tools].concat())
189            }
190            None => {
191                let static_tools = stream::iter(self.static_tools.iter())
192                    .filter_map(|toolname| async move {
193                        if let Some(tool) = self.tools.get(toolname) {
194                            // TODO: tool definitions should likely take an `Option<String>`
195                            Some(tool.definition("".into()).await)
196                        } else {
197                            tracing::warn!(
198                                "Tool implementation not found in toolset: {}",
199                                toolname
200                            );
201                            None
202                        }
203                    })
204                    .collect::<Vec<_>>()
205                    .await;
206
207                completion_request.tools(static_tools)
208            }
209        };
210
211        Ok(agent)
212    }
213}
214
// Here, we need to ensure that calling `.prompt` on an agent uses these redefinitions of the
//  opaque `Prompt` trait, so that `.prompt` at the call-site resolves to the more specific
//  `PromptRequest` implementation for `Agent`, keeping the builder's usage fluent.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
221
222#[allow(refining_impl_trait)]
223impl<M> Prompt for Agent<M>
224where
225    M: CompletionModel,
226{
227    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
228    fn prompt(
229        &self,
230        prompt: impl Into<Message> + Send,
231    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
232        PromptRequest::new(self, prompt)
233    }
234}
235
236#[allow(refining_impl_trait)]
237impl<M> Prompt for &Agent<M>
238where
239    M: CompletionModel,
240{
241    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
242    fn prompt(
243        &self,
244        prompt: impl Into<Message> + Send,
245    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
246        PromptRequest::new(*self, prompt)
247    }
248}
249
250#[allow(refining_impl_trait)]
251impl<M> Chat for Agent<M>
252where
253    M: CompletionModel,
254{
255    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
256    async fn chat(
257        &self,
258        prompt: impl Into<Message> + Send,
259        chat_history: Vec<Message>,
260    ) -> Result<String, PromptError> {
261        let mut cloned_history = chat_history.clone();
262        PromptRequest::new(self, prompt)
263            .with_history(&mut cloned_history)
264            .await
265    }
266}
267
268impl<M> StreamingCompletion<M> for Agent<M>
269where
270    M: CompletionModel,
271{
272    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
273    async fn stream_completion(
274        &self,
275        prompt: impl Into<Message> + Send,
276        chat_history: Vec<Message>,
277    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
278        // Reuse the existing completion implementation to build the request
279        // This ensures streaming and non-streaming use the same request building logic
280        self.completion(prompt, chat_history).await
281    }
282}
283
284impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
285where
286    M: CompletionModel + 'static,
287    M::StreamingResponse: GetTokenUsage,
288{
289    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
290    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()> {
291        let arc = Arc::new(self.clone());
292        StreamingPromptRequest::new(arc, prompt)
293    }
294}
295
296impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
297where
298    M: CompletionModel + 'static,
299    M::StreamingResponse: GetTokenUsage,
300{
301    fn stream_chat(
302        &self,
303        prompt: impl Into<Message> + Send,
304        chat_history: Vec<Message>,
305    ) -> StreamingPromptRequest<M, ()> {
306        let arc = Arc::new(self.clone());
307        StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
308    }
309}