rig/agent/completion.rs

use super::prompt_request::{self, PromptRequest};
use crate::{
    agent::prompt_request::streaming::StreamingPromptRequest,
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        GetTokenUsage, Message, Prompt, PromptError,
    },
    message::ToolChoice,
    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
    tool::server::ToolServerHandle,
    vector_store::{VectorStoreError, request::VectorSearchRequest},
    wasm_compat::WasmCompatSend,
};
use futures::{StreamExt, TryStreamExt, stream};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

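/// Shared store of dynamic context sources: each entry pairs the number of
/// samples to retrieve with a vector store index that is queried with the
/// prompt's RAG text at request time.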
pub type DynamicContextStore = Arc<
    RwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;

/// An LLM agent: a completion model combined with a preamble (i.e.: system
/// prompt) and a static set of context documents and tools.
/// All static context documents and tools are provided to the agent on every prompt.
///
/// # Example
/// ```no_run
/// use rig::{completion::Prompt, providers::openai};
///
/// # async fn example() {
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// # }
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M>
where
    M: CompletionModel,
{
    /// Name of the agent, used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents in an agent workflow or when converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Handle to the tool server that provides the agent's tool definitions
    pub tool_server_handle: ToolServerHandle,
    /// Dynamic context sources: vector store indices, each paired with the number of samples to retrieve per prompt
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
}

impl<M> Agent<M>
where
    M: CompletionModel,
{
    /// Returns the name of the agent, or a default placeholder if none is set.
    pub(crate) fn name(&self) -> &str {
        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }
}

impl<M> Completion<M> for Agent<M>
where
    M: CompletionModel,
{
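    /// Builds a completion request from the prompt and chat history, folding in
    /// the agent's preamble, static context, dynamic (RAG) context, and tool
    /// definitions.
    ///
    /// A minimal usage sketch (assumes an `agent` built from a provider client
    /// and that the returned builder is dispatched with `send()`; `ignore`d
    /// since it needs live credentials):
    ///
    /// ```ignore
    /// let builder = agent.completion("What is RAG?", vec![]).await?;
    /// let response = builder.send().await?;
    /// ```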
    async fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        let prompt = prompt.into();

        // Take the RAG text from the prompt if present; otherwise fall back to
        // the latest message in the chat history that contains RAG text
        let rag_text = prompt.rag_text();
        let rag_text = rag_text.or_else(|| {
            chat_history
                .iter()
                .rev()
                .find_map(|message| message.rag_text())
        });
        let completion_request = self
            .model
            .completion_request(prompt)
            .messages(chat_history)
            .temperature_opt(self.temperature)
            .max_tokens_opt(self.max_tokens)
            .additional_params_opt(self.additional_params.clone())
            .documents(self.static_context.clone());
        let completion_request = if let Some(preamble) = &self.preamble {
            completion_request.preamble(preamble.to_owned())
        } else {
            completion_request
        };
        let completion_request = if let Some(tool_choice) = &self.tool_choice {
            completion_request.tool_choice(tool_choice.clone())
        } else {
            completion_request
        };
        // If RAG text is available, use it to retrieve dynamic context documents
        // and to filter the tool definitions fetched from the tool server
        let completion_request = match &rag_text {
            Some(text) => {
                let dynamic_context = stream::iter(self.dynamic_context.read().await.iter())
                    .then(|(num_sample, index)| async {
                        let req = VectorSearchRequest::builder()
                            .query(text)
                            .samples(*num_sample as u64)
                            .build()
                            .expect("creating a VectorSearchRequest should not fail here: the query and sample count are always present");
                        Ok::<_, VectorStoreError>(
                            index
                                .top_n(req)
                                .await?
                                .into_iter()
                                .map(|(_, id, doc)| {
                                    // Pretty-print the document if possible for better readability
                                    let text = serde_json::to_string_pretty(&doc)
                                        .unwrap_or_else(|_| doc.to_string());

                                    Document {
                                        id,
                                        text,
                                        additional_props: HashMap::new(),
                                    }
                                })
                                .collect::<Vec<_>>(),
                        )
                    })
                    .try_fold(vec![], |mut acc, docs| async {
                        acc.extend(docs);
                        Ok(acc)
                    })
                    .await
                    .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(Some(text.to_string()))
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;

                completion_request
                    .documents(dynamic_context)
                    .tools(tooldefs)
            }
            None => {
                let tooldefs = self
                    .tool_server_handle
                    .get_tool_defs(None)
                    .await
                    .map_err(|_| {
                        CompletionError::RequestError("Failed to get tool definitions".into())
                    })?;

                completion_request.tools(tooldefs)
            }
        };

        Ok(completion_request)
    }
}

// These impls refine the return type of the opaque `Prompt` trait so that calling
// `.prompt` on an `Agent` resolves to the more specific `PromptRequest` builder at
// the call site, keeping the builder's usage fluent.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M> Prompt for Agent<M>
where
    M: CompletionModel,
{
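    /// Returns a [`PromptRequest`] builder for a single prompt turn. A minimal
    /// call-site sketch (`ignore`d; assumes a live `agent`):
    ///
    /// ```ignore
    /// let answer: String = agent.prompt("Entertain me!").await?;
    /// ```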
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M> Prompt for &Agent<M>
where
    M: CompletionModel,
{
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
        PromptRequest::new(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M> Chat for Agent<M>
where
    M: CompletionModel,
{
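    /// Prompts the agent with an owned chat history and resolves to the final
    /// text response. A minimal sketch (`ignore`d; assumes a live `agent`):
    ///
    /// ```ignore
    /// use rig::completion::Message;
    ///
    /// let history = vec![Message::user("My name is Alice.")];
    /// let reply = agent.chat("What's my name?", history).await?;
    /// ```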
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        mut chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        PromptRequest::new(self, prompt)
            .with_history(&mut chat_history)
            .await
    }
}

impl<M> StreamingCompletion<M> for Agent<M>
where
    M: CompletionModel,
{
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request.
        // This ensures streaming and non-streaming use the same request-building logic.
        self.completion(prompt, chat_history).await
    }
}

impl<M> StreamingPrompt<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
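    /// Creates a streaming prompt request. A consumption sketch (`ignore`d;
    /// assumes a live `agent`, and the chunk handling shown is schematic
    /// rather than the exact stream item type):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = agent.stream_prompt("Tell me a story").await;
    /// while let Some(chunk) = stream.next().await {
    ///     println!("{:?}", chunk);
    /// }
    /// ```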
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, ()> {
        let arc = Arc::new(self.clone());
        StreamingPromptRequest::new(arc, prompt)
    }
}

impl<M> StreamingChat<M, M::StreamingResponse> for Agent<M>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
{
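    /// Like [`StreamingPrompt::stream_prompt`], but seeds the request with an
    /// existing chat history before streaming the response.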
    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, ()> {
        let arc = Arc::new(self.clone());
        StreamingPromptRequest::new(arc, prompt).with_history(chat_history)
    }
}