// rig/agent/completion.rs

1use super::prompt_request::PromptRequest;
2use crate::{
3    completion::{
4        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
5        Message, Prompt, PromptError,
6    },
7    streaming::{StreamingChat, StreamingCompletion, StreamingCompletionResponse, StreamingPrompt},
8    tool::ToolSet,
9    vector_store::VectorStoreError,
10};
11use futures::{StreamExt, TryStreamExt, stream};
12use std::collections::HashMap;
13
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
pub struct Agent<M: CompletionModel> {
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: M,
    /// System prompt sent with every completion request
    pub preamble: String,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Tools that are always available to the agent (identified by their name)
    pub static_tools: Vec<String>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Dynamic context sources: each entry pairs the number of documents to
    /// sample per prompt with the vector store index to sample them from
    pub dynamic_context: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Dynamic tool sources: each entry pairs the number of tools to retrieve
    /// per prompt with the vector store index holding tool embeddings
    pub dynamic_tools: Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>,
    /// Actual tool implementations, looked up by name for both static and
    /// dynamic tools
    pub tools: ToolSet,
}
56
57impl<M: CompletionModel> Completion<M> for Agent<M> {
58    async fn completion(
59        &self,
60        prompt: impl Into<Message> + Send,
61        chat_history: Vec<Message>,
62    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
63        let prompt = prompt.into();
64
65        // Find the latest message in the chat history that contains RAG text
66        let rag_text = prompt.rag_text();
67        let rag_text = rag_text.or_else(|| {
68            chat_history
69                .iter()
70                .rev()
71                .find_map(|message| message.rag_text())
72        });
73
74        let completion_request = self
75            .model
76            .completion_request(prompt)
77            .preamble(self.preamble.clone())
78            .messages(chat_history)
79            .temperature_opt(self.temperature)
80            .max_tokens_opt(self.max_tokens)
81            .additional_params_opt(self.additional_params.clone())
82            .documents(self.static_context.clone());
83
84        // If the agent has RAG text, we need to fetch the dynamic context and tools
85        let agent = match &rag_text {
86            Some(text) => {
87                let dynamic_context = stream::iter(self.dynamic_context.iter())
88                    .then(|(num_sample, index)| async {
89                        Ok::<_, VectorStoreError>(
90                            index
91                                .top_n(text, *num_sample)
92                                .await?
93                                .into_iter()
94                                .map(|(_, id, doc)| {
95                                    // Pretty print the document if possible for better readability
96                                    let text = serde_json::to_string_pretty(&doc)
97                                        .unwrap_or_else(|_| doc.to_string());
98
99                                    Document {
100                                        id,
101                                        text,
102                                        additional_props: HashMap::new(),
103                                    }
104                                })
105                                .collect::<Vec<_>>(),
106                        )
107                    })
108                    .try_fold(vec![], |mut acc, docs| async {
109                        acc.extend(docs);
110                        Ok(acc)
111                    })
112                    .await
113                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
114
115                let dynamic_tools = stream::iter(self.dynamic_tools.iter())
116                    .then(|(num_sample, index)| async {
117                        Ok::<_, VectorStoreError>(
118                            index
119                                .top_n_ids(text, *num_sample)
120                                .await?
121                                .into_iter()
122                                .map(|(_, id)| id)
123                                .collect::<Vec<_>>(),
124                        )
125                    })
126                    .try_fold(vec![], |mut acc, docs| async {
127                        for doc in docs {
128                            if let Some(tool) = self.tools.get(&doc) {
129                                acc.push(tool.definition(text.into()).await)
130                            } else {
131                                tracing::warn!("Tool implementation not found in toolset: {}", doc);
132                            }
133                        }
134                        Ok(acc)
135                    })
136                    .await
137                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
138
139                let static_tools = stream::iter(self.static_tools.iter())
140                    .filter_map(|toolname| async move {
141                        if let Some(tool) = self.tools.get(toolname) {
142                            Some(tool.definition(text.into()).await)
143                        } else {
144                            tracing::warn!(
145                                "Tool implementation not found in toolset: {}",
146                                toolname
147                            );
148                            None
149                        }
150                    })
151                    .collect::<Vec<_>>()
152                    .await;
153
154                completion_request
155                    .documents(dynamic_context)
156                    .tools([static_tools.clone(), dynamic_tools].concat())
157            }
158            None => {
159                let static_tools = stream::iter(self.static_tools.iter())
160                    .filter_map(|toolname| async move {
161                        if let Some(tool) = self.tools.get(toolname) {
162                            // TODO: tool definitions should likely take an `Option<String>`
163                            Some(tool.definition("".into()).await)
164                        } else {
165                            tracing::warn!(
166                                "Tool implementation not found in toolset: {}",
167                                toolname
168                            );
169                            None
170                        }
171                    })
172                    .collect::<Vec<_>>()
173                    .await;
174
175                completion_request.tools(static_tools)
176            }
177        };
178
179        Ok(agent)
180    }
181}
182
// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for Agent<M> {
    /// Create a [`PromptRequest`] builder targeting this agent; the request
    /// is sent when the returned value is awaited.
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(self, prompt)
    }
}
196
#[allow(refining_impl_trait)]
impl<M: CompletionModel> Prompt for &Agent<M> {
    /// Same as the `Prompt` impl for `Agent`, but usable through a shared
    /// reference; delegates to the referenced agent.
    fn prompt(&self, prompt: impl Into<Message> + Send) -> PromptRequest<M> {
        PromptRequest::new(*self, prompt)
    }
}
203
204#[allow(refining_impl_trait)]
205impl<M: CompletionModel> Chat for Agent<M> {
206    async fn chat(
207        &self,
208        prompt: impl Into<Message> + Send,
209        chat_history: Vec<Message>,
210    ) -> Result<String, PromptError> {
211        let mut cloned_history = chat_history.clone();
212        PromptRequest::new(self, prompt)
213            .with_history(&mut cloned_history)
214            .await
215    }
216}
217
impl<M: CompletionModel> StreamingCompletion<M> for Agent<M> {
    /// Build a completion request suitable for streaming.
    ///
    /// # Errors
    /// Propagates any [`CompletionError`] from the underlying request builder.
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + Send,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}
229
impl<M: CompletionModel> StreamingPrompt<M::StreamingResponse> for Agent<M> {
    /// Stream a one-shot prompt: delegates to `stream_chat` with an empty
    /// chat history.
    async fn stream_prompt(
        &self,
        prompt: impl Into<Message> + Send,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
        self.stream_chat(prompt, vec![]).await
    }
}
238
239impl<M: CompletionModel> StreamingChat<M::StreamingResponse> for Agent<M> {
240    async fn stream_chat(
241        &self,
242        prompt: impl Into<Message> + Send,
243        chat_history: Vec<Message>,
244    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError> {
245        self.stream_completion(prompt, chat_history)
246            .await?
247            .stream()
248            .await
249    }
250}