// rig/agent/completion.rs

use super::prompt_request::{self, PromptRequest};
use crate::{
    agent::prompt_request::streaming::StreamingPromptRequest,
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        GetTokenUsage, Message, Prompt, PromptError,
    },
    message::ToolChoice,
    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
    tool::server::ToolServerHandle,
    vector_store::{VectorStoreError, request::VectorSearchRequest},
    wasm_compat::WasmCompatSend,
};
use futures::{StreamExt, TryStreamExt, stream};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

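/// Shared store of dynamic context sources: each entry pairs the number of
/// samples to retrieve with a dynamic vector store index that is queried at
/// prompt time.
///
/// A minimal sketch of how entries typically end up here, assuming the
/// `AgentBuilder::dynamic_context` method and an existing vector index
/// (`index` here is hypothetical):
///
/// ```ignore
/// let agent = openai
///     .agent("gpt-4o")
///     // Retrieve the top 3 matching documents from `index` on every prompt.
///     .dynamic_context(3, index)
///     .build();
/// ```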
pub type DynamicContextStore =
    Arc<RwLock<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>>;

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// # Example
/// ```no_run
/// use rig::{completion::Prompt, providers::openai};
///
/// #[tokio::main]
/// async fn main() {
///     let openai = openai::Client::from_env();
///
///     let comedian_agent = openai
///         .agent("gpt-4o")
///         .preamble("You are a comedian here to entertain the user using humour and jokes.")
///         .temperature(0.9)
///         .build();
///
///     let response = comedian_agent.prompt("Entertain me!")
///         .await
///         .expect("Failed to prompt the agent");
/// }
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Name of the agent, used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Handle to the tool server that supplies the agent's tool definitions
    pub tool_server_handle: ToolServerHandle,
    /// List of vector stores, each paired with the number of samples to retrieve from it
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
}

impl<M> Agent<M>
where
    M: CompletionModel,
{
    /// Returns the name of the agent, or a default placeholder if none was set.
    pub(crate) fn name(&self) -> &str {
        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }
}

impl<M> Completion<M> for Agent<M>
where
    M: CompletionModel,
{
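    /// Builds a low-level [`CompletionRequestBuilder`] from the prompt, chat
    /// history, and the agent's configuration (preamble, static/dynamic
    /// context, tools). A minimal sketch of overriding a setting before
    /// sending, assuming the builder's `send` method:
    ///
    /// ```ignore
    /// let response = agent
    ///     .completion("Entertain me!", vec![])
    ///     .await?
    ///     .temperature(0.5) // override the agent-level temperature
    ///     .send()
    ///     .await?;
    /// ```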
    async fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();

        // Take the RAG text from the prompt if present; otherwise fall back to
        // the latest message in the chat history that contains RAG text
        let rag_text = prompt.rag_text();
        let rag_text = rag_text.or_else(|| {
            chat_history
                .iter()
                .rev()
                .find_map(|message| message.rag_text())
        });

        let completion_request = self
            .model
            .completion_request(prompt)
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());
        let completion_request = if let Some(preamble) = &self.preamble {
            completion_request.preamble(preamble.to_owned())
        } else {
            completion_request
        };

        // If there is RAG text, fetch dynamic context documents and tool
        // definitions scoped to it; otherwise fetch the unscoped tool definitions
        let completion_request = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
                    .then(|(num_sample, index)| async {
                        let req = VectorSearchRequest::builder()
                            .query(text)
                            .samples(*num_sample as u64)
                            .build()
                            .expect(
                                "Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present",
                            );
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(req)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(Some(text.to_string()))
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;

                completion_request
                    .documents(dynamic_context)
                    .tools(tooldefs)
            }
            None => {
                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(None)
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;

                completion_request.tools(tooldefs)
            }
        };

        Ok(completion_request)
    }
}

// These impls refine the opaque return type of the `Prompt` trait: when `.prompt`
// is called on an `Agent` (or `&Agent`), the call-site gets the more specific
// `PromptRequest` builder rather than an opaque future, keeping the builder's
// usage fluent.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M> Prompt for Agent<M>
where
    M: CompletionModel,
{
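    /// Returns the concrete [`PromptRequest`] builder so call-sites can chain
    /// configuration before awaiting. A minimal sketch, assuming the builder's
    /// `multi_turn` method:
    ///
    /// ```ignore
    /// let answer = agent
    ///     .prompt("Plan a weekend trip")
    ///     .multi_turn(5) // allow up to 5 tool-call round trips
    ///     .await?;
    /// ```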
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M> Prompt for &Agent<M>
where
    M: CompletionModel,
{
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M> Chat for Agent<M>
where
    M: CompletionModel,
{
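    /// Prompts the agent with a prior chat history and returns the final
    /// response text. A minimal sketch, assuming the `Message` constructors
    /// shown:
    ///
    /// ```ignore
    /// let history = vec![
    ///     Message::user("What's the capital of France?"),
    ///     Message::assistant("Paris."),
    /// ];
    /// let reply = agent.chat("And of Germany?", history).await?;
    /// ```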
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        mut chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        PromptRequest::new(self, prompt)
            .with_history(&mut chat_history)
            .await
    }
}

impl<M> StreamingCompletion<M> for Agent<M>
where
    M: CompletionModel,
{
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request.
        // This ensures streaming and non-streaming use the same request-building logic.
        self.completion(prompt, chat_history).await
    }
}

impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
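    /// Returns a [`StreamingPromptRequest`] that can be awaited into a stream
    /// of response chunks. A minimal sketch, assuming the awaited request
    /// yields a `futures::Stream` of chunk results:
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = agent.stream_prompt("Tell me a story").await;
    /// while let Some(chunk) = stream.next().await {
    ///     println!("{chunk:?}"); // exact chunk type depends on the provider
    /// }
    /// ```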
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()> {
        let arc = Arc::new(self.clone());
        StreamingPromptRequest::new(arc, prompt)
    }
}

impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
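    /// Like `stream_prompt`, but seeds the request with an existing chat
    /// history. A minimal sketch under the same streaming assumptions as
    /// `stream_prompt` (`history` is hypothetical):
    ///
    /// ```ignore
    /// let mut stream = agent.stream_chat("Continue the story", history).await;
    /// ```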
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()> {
        let arc = Arc::new(self.clone());
        StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
    }
}