rig/agent/
completion.rs

1use super::prompt_request::{self, PromptRequest};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError,
7    },
8    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
9    tool::ToolSet,
10    vector_store::{VectorStoreError, request::VectorSearchRequest},
11};
12use futures::{StreamExt, TryStreamExt, stream};
13use std::collections::HashMap;
14
/// Fallback display name used in logs/traces when an `Agent` was built without an explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
16
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[non_exhaustive]
pub struct Agent<M: CompletionModel> {
    /// Name of the agent used for logging and debugging.
    /// Falls back to a placeholder (see `UNKNOWN_AGENT_NAME`) when `None`.
    pub name: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: M,
    /// System prompt sent with every completion request
    pub preamble: String,
    /// Context documents always attached to the agent's completion requests
    pub static_context: Vec<Document>,
    /// Names of tools that are always available to the agent; the
    /// implementations are looked up in `tools` by these names
    pub static_tools: Vec<String>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Vector store indexes used to retrieve dynamic context documents,
    /// each paired with the number of documents to sample per prompt
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Vector store indexes used to retrieve dynamic tool names,
    /// each paired with the number of tools to sample per prompt
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Actual tool implementations, keyed by tool name (covers both
    /// `static_tools` and tools discovered via `dynamic_tools`)
    pub tools: ToolSet,
}
62
63impl<M: CompletionModel> Completion<M> for Agent<M> {
64    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
65    async fn completion(
66        &self,
67        prompt: impl Into<Message> + Send,
68        chat_history: Vec<Message>,
69    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
70        let prompt = prompt.into();
71
72        // Find the latest message in the chat history that contains RAG text
73        let rag_text = prompt.rag_text();
74        let rag_text = rag_text.or_else(|| {
75            chat_history
76                .iter()
77                .rev()
78                .find_map(|message| message.rag_text())
79        });
80
81        let completion_request = self
82            .model
83            .completion_request(prompt)
84            .preamble(self.preamble.clone())
85            .messages(chat_history)
86            .temperature_opt(self.temperature)
87            .max_tokens_opt(self.max_tokens)
88            .additional_params_opt(self.additional_params.clone())
89            .documents(self.static_context.clone());
90
91        // If the agent has RAG text, we need to fetch the dynamic context and tools
92        let agent = match &rag_text {
93            Some(text) => {
94                let dynamic_context = stream::iter(self.dynamic_context.iter())
95                    .then(|(num_sample, index)| async {
96                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
97                        Ok::<_, VectorStoreError>(
98                            index
99                                .top_n(req)
100                                .await?
101                                .into_iter()
102                                .map(|(_, id, doc)| {
103                                    // Pretty print the document if possible for better readability
104                                    let text = serde_json::to_string_pretty(&doc)
105                                        .unwrap_or_else(|_| doc.to_string());
106
107                                    Document {
108                                        id,
109                                        text,
110                                        additional_props: HashMap::new(),
111                                    }
112                                })
113                                .collect::<Vec<_>>(),
114                        )
115                    })
116                    .try_fold(vec![], |mut acc, docs| async {
117                        acc.extend(docs);
118                        Ok(acc)
119                    })
120                    .await
121                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
122
123                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
124                    .then(|(num_sample, index)| async {
125                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
126                        Ok::<_, VectorStoreError>(
127                            index
128                                .top_n_ids(req)
129                                .await?
130                                .into_iter()
131                                .map(|(_, id)| id)
132                                .collect::<Vec<_>>(),
133                        )
134                    })
135                    .try_fold(vec![], |mut acc, docs| async {
136                        for doc in docs {
137                            if let Some(tool) = self.tools.get(&doc) {
138                                acc.push(tool.definition(text.into()).await)
139                            } else {
140                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
141                            }
142                        }
143                        Ok(acc)
144                    })
145                    .await
146                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
147
148                let static_tools = stream::iter(self.static_tools.iter())
149                    .filter_map(|toolname| async move {
150                        if let Some(tool) = self.tools.get(toolname) {
151                            Some(tool.definition(text.into()).await)
152                        } else {
153                            tracing::warn!(
154                                "Tool implementation not found in toolset: {}",
155                                toolname
156                            );
157                            None
158                        }
159                    })
160                    .collect::<Vec<_>>()
161                    .await;
162
163                completion_request
164                    .documents(dynamic_context)
165                    .tools([static_tools.clone(), dynamic_tools].concat())
166            }
167            None => {
168                let static_tools = stream::iter(self.static_tools.iter())
169                    .filter_map(|toolname| async move {
170                        if let Some(tool) = self.tools.get(toolname) {
171                            // TODO: tool definitions should likely take an `Option<String>`
172                            Some(tool.definition("".into()).await)
173                        } else {
174                            tracing::warn!(
175                                "Tool implementation not found in toolset: {}",
176                                toolname
177                            );
178                            None
179                        }
180                    })
181                    .collect::<Vec<_>>()
182                    .await;
183
184                completion_request.tools(static_tools)
185            }
186        };
187
188        Ok(agent)
189    }
190}
191
192// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
193//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
194//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
195//
196// References:
197//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
198
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Creates a lazy [`PromptRequest`] builder for this agent; nothing is
    /// sent to the model until the returned request is awaited.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<'_, prompt_request::Standard, M> {
        PromptRequest::new(self, prompt)
    }
}
209
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Same as [`Prompt`] for `Agent<M>`, but usable through a shared
    /// reference; `*self` re-borrows the underlying agent for the request.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> PromptRequest<'_, prompt_request::Standard, M> {
        PromptRequest::new(*self, prompt)
    }
}
220
221#[allow(refining_impl_trait)]
222impl<M: CompletionModel> Chat for Agent<M> {
223    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
224    async fn chat(
225        &self,
226        prompt: impl Into<Message> + Send,
227        chat_history: Vec<Message>,
228    ) -> Result<String, PromptError> {
229        let mut cloned_history = chat_history.clone();
230        PromptRequest::new(self, prompt)
231            .with_history(&mut cloned_history)
232            .await
233    }
234}
235
236impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
237    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
238    async fn stream_completion(
239        &self,
240        prompt: impl Into<Message> + Send,
241        chat_history: Vec<Message>,
242    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
243        // Reuse the existing completion implementation to build the request
244        // This ensures streaming and non-streaming use the same request building logic
245        self.completion(prompt, chat_history).await
246    }
247}
248
impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
    /// Creates a lazy [`StreamingPromptRequest`] for this agent; the stream
    /// is only started once the returned request is awaited/consumed.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<'_, M> {
        StreamingPromptRequest::new(self, prompt)
    }
}
259
260impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
261where
262    M: CompletionModel + 'static,
263    M::StreamingResponse: GetTokenUsage,
264{
265    fn stream_chat(
266        &self,
267        prompt: impl Into<Message> + Send,
268        chat_history: Vec<Message>,
269    ) -> StreamingPromptRequest<'_, M> {
270        StreamingPromptRequest::new(self, prompt).with_history(chat_history)
271    }
272}
273
274impl<M: CompletionModel> Agent<M> {
275    /// Returns the name of the agent.
276    pub(crate) fn name(&self) -> &str {
277        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
278    }
279
280    /// Returns the name of the agent as an owned variable.
281    /// Useful in some cases where having the agent name as an owned variable is required.
282    pub(crate) fn name_owned(&self) -> String {
283        self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string())
284    }
285}