rig/agent/completion.rs

1use std::collections::HashMap;
2
3use futures::{stream, StreamExt, TryStreamExt};
4
5use crate::{
6    completion::{
7        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
8        Message, Prompt, PromptError,
9    },
10    streaming::{
11        StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
12        StreamingResult,
13    },
14    tool::ToolSet,
15    vector_store::VectorStoreError,
16};
17
18use super::prompt_request::PromptRequest;
19
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: M,
    /// System prompt sent with every completion request
    pub preamble: String,
    /// Context documents always available to the agent (attached to every request)
    pub static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name);
    /// the actual implementations live in `tools`
    pub static_tools: Vec<String>,
    /// Temperature of the model (`None` leaves the provider default)
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion (`None` leaves the provider default)
    pub max_tokens: Option<u64>,
    /// Additional provider-specific parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Vector store indices used for dynamic context, each paired with the number
    /// of documents to sample per request
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Vector store indices used to discover dynamic tools, each paired with the
    /// number of tool ids to sample per request
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Actual tool implementations, looked up by name for both static and dynamic tools
    pub tools: ToolSet,
}
62
63impl<M: CompletionModel> Completion<M> for Agent<M> {
64    async fn completion(
65        &self,
66        prompt: impl Into<Message> + Send,
67        chat_history: Vec<Message>,
68    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
69        let prompt = prompt.into();
70
71        // Find the latest message in the chat history that contains RAG text
72        let rag_text = prompt.rag_text();
73        let rag_text = rag_text.or_else(|| {
74            chat_history
75                .iter()
76                .rev()
77                .find_map(|message| message.rag_text())
78        });
79
80        let completion_request = self
81            .model
82            .completion_request(prompt)
83            .preamble(self.preamble.clone())
84            .messages(chat_history)
85            .temperature_opt(self.temperature)
86            .max_tokens_opt(self.max_tokens)
87            .additional_params_opt(self.additional_params.clone())
88            .documents(self.static_context.clone());
89
90        // If the agent has RAG text, we need to fetch the dynamic context and tools
91        let agent = match &rag_text {
92            Some(text) => {
93                let dynamic_context = stream::iter(self.dynamic_context.iter())
94                    .then(|(num_sample, index)| async {
95                        Ok::<_, VectorStoreError>(
96                            index
97                                .top_n(text, *num_sample)
98                                .await?
99                                .into_iter()
100                                .map(|(_, id, doc)| {
101                                    // Pretty print the document if possible for better readability
102                                    let text = serde_json::to_string_pretty(&doc)
103                                        .unwrap_or_else(|_| doc.to_string());
104
105                                    Document {
106                                        id,
107                                        text,
108                                        additional_props: HashMap::new(),
109                                    }
110                                })
111                                .collect::<Vec<_>>(),
112                        )
113                    })
114                    .try_fold(vec![], |mut acc, docs| async {
115                        acc.extend(docs);
116                        Ok(acc)
117                    })
118                    .await
119                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
120
121                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
122                    .then(|(num_sample, index)| async {
123                        Ok::<_, VectorStoreError>(
124                            index
125                                .top_n_ids(text, *num_sample)
126                                .await?
127                                .into_iter()
128                                .map(|(_, id)| id)
129                                .collect::<Vec<_>>(),
130                        )
131                    })
132                    .try_fold(vec![], |mut acc, docs| async {
133                        for doc in docs {
134                            if let Some(tool) = self.tools.get(&doc) {
135                                acc.push(tool.definition(text.into()).await)
136                            } else {
137                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
138                            }
139                        }
140                        Ok(acc)
141                    })
142                    .await
143                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
144
145                let static_tools = stream::iter(self.static_tools.iter())
146                    .filter_map(|toolname| async move {
147                        if let Some(tool) = self.tools.get(toolname) {
148                            Some(tool.definition(text.into()).await)
149                        } else {
150                            tracing::warn!(
151                                "Tool implementation not found in toolset: {}",
152                                toolname
153                            );
154                            None
155                        }
156                    })
157                    .collect::<Vec<_>>()
158                    .await;
159
160                completion_request
161                    .documents(dynamic_context)
162                    .tools([static_tools.clone(), dynamic_tools].concat())
163            }
164            None => {
165                let static_tools = stream::iter(self.static_tools.iter())
166                    .filter_map(|toolname| async move {
167                        if let Some(tool) = self.tools.get(toolname) {
168                            // TODO: tool definitions should likely take an `Option<String>`
169                            Some(tool.definition("".into()).await)
170                        } else {
171                            tracing::warn!(
172                                "Tool implementation not found in toolset: {}",
173                                toolname
174                            );
175                            None
176                        }
177                    })
178                    .collect::<Vec<_>>()
179                    .await;
180
181                completion_request.tools(static_tools)
182            }
183        };
184
185        Ok(agent)
186    }
187}
188
189// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
190//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
191//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
192//
193// References:
194//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
195
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Create a [`PromptRequest`] builder targeting this agent; awaiting it
    /// sends the prompt.
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(self, prompt)
    }
}
202
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Same as the owned-`Agent` impl; `*self` re-borrows the inner
    /// `&Agent<M>` so the request borrows the agent, not the reference.
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(*self, prompt)
    }
}
209
210#[allow(refining_impl_trait)]
211impl<M: CompletionModel> Chat for Agent<M> {
212    async fn chat(
213        &self,
214        prompt: impl Into<Message> + Send,
215        chat_history: Vec<Message>,
216    ) -> Result<String, PromptError> {
217        let mut cloned_history = chat_history.clone();
218        PromptRequest::new(self, prompt)
219            .with_history(&mut cloned_history)
220            .await
221    }
222}
223
224impl<M: StreamingCompletionModel> StreamingCompletion<M> for Agent<M> {
225    async fn stream_completion(
226        &self,
227        prompt: impl Into<Message> + Send,
228        chat_history: Vec<Message>,
229    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
230        // Reuse the existing completion implementation to build the request
231        // This ensures streaming and non-streaming use the same request building logic
232        self.completion(prompt, chat_history).await
233    }
234}
235
236impl<M: StreamingCompletionModel> StreamingPrompt for Agent<M> {
237    async fn stream_prompt(&self, prompt: &str) -> Result<StreamingResult, CompletionError> {
238        self.stream_chat(prompt, vec![]).await
239    }
240}
241
242impl<M: StreamingCompletionModel> StreamingChat for Agent<M> {
243    async fn stream_chat(
244        &self,
245        prompt: &str,
246        chat_history: Vec<Message>,
247    ) -> Result<StreamingResult, CompletionError> {
248        self.stream_completion(prompt, chat_history)
249            .await?
250            .stream()
251            .await
252    }
253}