rig/agent/
completion.rs

1use super::prompt_request::{self, PromptRequest};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError,
7    },
8    message::ToolChoice,
9    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10    tool::server::ToolServerHandle,
11    vector_store::{VectorStoreError, request::VectorSearchRequest},
12    wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock;
17
/// Fallback display name used in logs/traces when an agent has no explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared, read/write-locked list of dynamic-context vector stores.
///
/// Each entry pairs the number of samples to retrieve (`usize`) with a
/// type-erased vector store index that is queried at prompt time to add
/// RAG documents to the completion request.
pub type DynamicContextStore = Arc<
    RwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Name of the agent used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch the agent's tool definitions when building requests
    pub tool_server_handle: ToolServerHandle,
    /// List of vector store, with the sample number
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Default maximum depth for recursive agent calls
    pub default_max_depth: Option<usize>,
}
79
80impl<M> Agent<M>
81where
82    M: CompletionModel,
83{
84    /// Returns the name of the agent.
85    pub(crate) fn name(&self) -> &str {
86        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
87    }
88}
89
impl<M> Completion<M> for Agent<M>
where
    M: CompletionModel,
{
    /// Builds the completion request for `prompt` given an existing `chat_history`.
    ///
    /// The agent's static configuration (preamble, static context documents,
    /// temperature, max tokens, additional params, tool choice) is always
    /// applied. When RAG text can be extracted from the prompt (or, failing
    /// that, the most recent history message), dynamic context documents are
    /// retrieved from the configured vector stores and the RAG text is passed
    /// along to the tool-definition fetch; otherwise tool definitions are
    /// fetched with no query text.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] if a vector store lookup or
    /// the tool-definition fetch fails.
    async fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();

        // Prefer RAG text from the prompt itself; otherwise fall back to the
        // latest message in the chat history that contains RAG text.
        let rag_text = prompt.rag_text();
        let rag_text = rag_text.or_else(|| {
            chat_history
                .iter()
                .rev()
                .find_map(|message| message.rag_text())
        });

        // Base request carrying the agent's static configuration.
        let completion_request = self
            .model
            .completion_request(prompt)
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());
        let completion_request = if let Some(preamble) = &self.preamble {
            completion_request.preamble(preamble.to_owned())
        } else {
            completion_request
        };
        let completion_request = if let Some(tool_choice) = &self.tool_choice {
            completion_request.tool_choice(tool_choice.clone())
        } else {
            completion_request
        };

        // If the agent has RAG text, we need to fetch the dynamic context and tools
        let agent = match &rag_text {
            Some(text) => {
                // Query each configured vector store for its configured number
                // of samples, flattening all results into one document list.
                // The RwLock read guard is held for the duration of the stream.
                let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
                    .then(|(num_sample, index)| async {
                        let req = VectorSearchRequest::builder().query(text).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(req)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                // Fetch tool definitions, passing the RAG text (presumably to
                // scope tool selection — see `ToolServerHandle::get_tool_defs`).
                // NOTE(review): the underlying error is discarded here.
                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(Some(text.to_string()))
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;

                completion_request
                    .documents(dynamic_context)
                    .tools(tooldefs)
            }
            None => {
                // No RAG text: fetch tool definitions with no query text.
                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(None)
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;

                completion_request.tools(tooldefs)
            }
        };

        Ok(agent)
    }
}
189
190// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
191//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
192//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
193//
194// References:
195//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
196
197#[allow(refining_impl_trait)]
198impl<M> Prompt for Agent<M>
199where
200    M: CompletionModel,
201{
202    fn prompt(
203        &self,
204        prompt: impl Into<Message> + WasmCompatSend,
205    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
206        PromptRequest::new(self, prompt)
207    }
208}
209
210#[allow(refining_impl_trait)]
211impl<M> Prompt for &Agent<M>
212where
213    M: CompletionModel,
214{
215    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
216    fn prompt(
217        &self,
218        prompt: impl Into<Message> + WasmCompatSend,
219    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
220        PromptRequest::new(*self, prompt)
221    }
222}
223
224#[allow(refining_impl_trait)]
225impl<M> Chat for Agent<M>
226where
227    M: CompletionModel,
228{
229    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
230    async fn chat(
231        &self,
232        prompt: impl Into<Message> + WasmCompatSend,
233        mut chat_history: Vec<Message>,
234    ) -> Result<String, PromptError> {
235        PromptRequest::new(self, prompt)
236            .with_history(&mut chat_history)
237            .await
238    }
239}
240
241impl<M> StreamingCompletion<M> for Agent<M>
242where
243    M: CompletionModel,
244{
245    async fn stream_completion(
246        &self,
247        prompt: impl Into<Message> + WasmCompatSend,
248        chat_history: Vec<Message>,
249    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
250        // Reuse the existing completion implementation to build the request
251        // This ensures streaming and non-streaming use the same request building logic
252        self.completion(prompt, chat_history).await
253    }
254}
255
256impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
257where
258    M: CompletionModel + 'static,
259    M::StreamingResponse: GetTokenUsage,
260{
261    fn stream_prompt(
262        &self,
263        prompt: impl Into<Message> + WasmCompatSend,
264    ) -> StreamingPromptRequest<M, ()> {
265        let arc = Arc::new(self.clone());
266        StreamingPromptRequest::new(arc, prompt)
267    }
268}
269
270impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
271where
272    M: CompletionModel + 'static,
273    M::StreamingResponse: GetTokenUsage,
274{
275    fn stream_chat(
276        &self,
277        prompt: impl Into<Message> + WasmCompatSend,
278        chat_history: Vec<Message>,
279    ) -> StreamingPromptRequest<M, ()> {
280        let arc = Arc::new(self.clone());
281        StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
282    }
283}