rig/agent/completion.rs

use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
use crate::{
    agent::prompt_request::streaming::StreamingPromptRequest,
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        GetTokenUsage, Message, Prompt, PromptError,
    },
    message::ToolChoice,
    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
    tool::server::ToolServerHandle,
    vector_store::{VectorStoreError, request::VectorSearchRequest},
    wasm_compat::WasmCompatSend,
};
use futures::{StreamExt, TryStreamExt, stream};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock as TokioRwLock;

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

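/// Shared, read-write collection of `(num_samples, index)` pairs used to fetch
/// dynamic (RAG) context at prompt time.
///
/// A minimal sketch of how this is typically populated through the agent builder,
/// assuming an `index` value implementing `VectorStoreIndexDyn`:
///
/// ```rust,ignore
/// let agent = openai
///     .agent("gpt-4o")
///     .dynamic_context(3, index) // retrieve the top 3 matching documents per prompt
///     .build();
/// ```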
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;

/// Helper function to build a completion request from agent components.
/// This is used by both `Agent::completion()` and `PromptRequest::send()`.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
    model: &Arc<M>,
    prompt: Message,
    chat_history: Vec<Message>,
    preamble: Option<&str>,
    static_context: &[Document],
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<&serde_json::Value>,
    tool_choice: Option<&ToolChoice>,
    tool_server_handle: &ToolServerHandle,
    dynamic_context: &DynamicContextStore,
    output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
    // Find the latest message in the chat history that contains RAG text
    let rag_text = prompt.rag_text();
    let rag_text = rag_text.or_else(|| {
        chat_history
            .iter()
            .rev()
            .find_map(|message| message.rag_text())
    });

    let completion_request = model
        .completion_request(prompt)
        .messages(chat_history)
        .temperature_opt(temperature)
        .max_tokens_opt(max_tokens)
        .additional_params_opt(additional_params.cloned())
        .output_schema_opt(output_schema.cloned())
        .documents(static_context.to_vec());

    let completion_request = if let Some(preamble) = preamble {
        completion_request.preamble(preamble.to_owned())
    } else {
        completion_request
    };

    let completion_request = if let Some(tool_choice) = tool_choice {
        completion_request.tool_choice(tool_choice.clone())
    } else {
        completion_request
    };

    // If the agent has RAG text, we need to fetch the dynamic context and tools
    let result = match &rag_text {
        Some(text) => {
            let fetched_context = stream::iter(dynamic_context.read().await.iter())
                .then(|(num_sample, index)| async {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(*num_sample as u64)
                        .build()
                        .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                    Ok::<_, VectorStoreError>(
                        index
                            .top_n(req)
                            .await?
                            .into_iter()
                            .map(|(_, id, doc)| {
                                // Pretty print the document if possible for better readability
                                let text = serde_json::to_string_pretty(&doc)
                                    .unwrap_or_else(|_| doc.to_string());

                                Document {
                                    id,
                                    text,
                                    additional_props: HashMap::new(),
                                }
                            })
                            .collect::<Vec<_>>(),
                    )
                })
                .try_fold(vec![], |mut acc, docs| async {
                    acc.extend(docs);
                    Ok(acc)
                })
                .await
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            let tooldefs = tool_server_handle
                .get_tool_defs(Some(text.to_string()))
                .await
                .map_err(|_| {
                    CompletionError::RequestError("Failed to get tool definitions".into())
                })?;

            completion_request
                .documents(fetched_context)
                .tools(tooldefs)
        }
        None => {
            let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
                CompletionError::RequestError("Failed to get tool definitions".into())
            })?;

            completion_request.tools(tooldefs)
        }
    };

    Ok(result)
}

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// The optional type parameter `P` represents a default hook that will be used for all
/// prompt requests unless overridden via `.with_hook()` on the request.
///
/// # Example
/// ```rust,ignore
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name of the agent used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
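    /// Handle to the agent's tool server, used to fetch tool definitions for completion requests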
    pub tool_server_handle: ToolServerHandle,
    /// List of vector store indices, each paired with the number of documents to sample per query
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Default maximum number of turns (tool-call round trips) for multi-turn prompt requests
    pub default_max_turns: Option<usize>,
    /// Default hook for this agent, used when no per-request hook is provided
    pub hook: Option<P>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}

impl<M, P> Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Returns the name of the agent, or a default placeholder if no name was set.
    pub(crate) fn name(&self) -> &str {
        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }
}

impl<M, P> Completion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
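    /// Builds a completion request that can be further customised before sending.
    ///
    /// A minimal sketch of the intended call shape, assuming the builder's
    /// `temperature` and `send` methods:
    ///
    /// ```rust,ignore
    /// let response = agent
    ///     .completion("Hello!", vec![]) // prompt plus an empty chat history
    ///     .await?
    ///     .temperature(0.2) // tweak the builder before sending
    ///     .send()
    ///     .await?;
    /// ```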
    async fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        build_completion_request(
            &self.model,
            prompt.into(),
            chat_history,
            self.preamble.as_deref(),
            &self.static_context,
            self.temperature,
            self.max_tokens,
            self.additional_params.as_ref(),
            self.tool_choice.as_ref(),
            &self.tool_server_handle,
            &self.dynamic_context,
            self.output_schema.as_ref(),
        )
        .await
    }
}

// These impls refine the opaque `Prompt` trait so that calling `.prompt` on an agent
// resolves to the more specific `PromptRequest` return type for `Agent`, keeping the
// builder's usage fluent at the call site.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
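    /// Returns a fluent [`PromptRequest`]; awaiting it sends the prompt.
    ///
    /// A minimal sketch, using the same `max_turns` builder method shown in the
    /// `prompt_typed` example below:
    ///
    /// ```rust,ignore
    /// let answer = agent
    ///     .prompt("What tools do you have access to?")
    ///     .max_turns(3) // allow up to 3 tool-call round trips
    ///     .await?;
    /// ```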
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
        PromptRequest::from_agent(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> Chat for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
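    /// Sends a prompt along with prior chat history and returns the final reply text.
    ///
    /// A minimal sketch, assuming the `Message::user`/`Message::assistant` constructors:
    ///
    /// ```rust,ignore
    /// let history = vec![
    ///     Message::user("What's the capital of France?"),
    ///     Message::assistant("Paris."),
    /// ];
    /// let reply = agent.chat("And of Germany?", history).await?;
    /// ```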
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        mut chat_history: Vec<Message>,
    ) -> Result<String, PromptError> {
        PromptRequest::from_agent(self, prompt)
            .with_history(&mut chat_history)
            .await
    }
}

impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
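    /// Builds a completion request for streaming; the builder is shared with the
    /// non-streaming path, so both produce identical requests.
    ///
    /// A minimal sketch, assuming the builder's `stream` method:
    ///
    /// ```rust,ignore
    /// let mut stream = agent
    ///     .stream_completion("Hello!", vec![])
    ///     .await?
    ///     .stream()
    ///     .await?;
    /// ```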
    async fn stream_completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;

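    /// Returns a [`StreamingPromptRequest`] that yields response chunks as they arrive.
    ///
    /// A minimal sketch; the exact chunk type depends on the provider's streaming
    /// response, and `StreamExt` comes from the `futures` crate:
    ///
    /// ```rust,ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = agent.stream_prompt("Tell me a story").await;
    /// while let Some(chunk) = stream.next().await {
    ///     // Handle each streamed chunk (text deltas, tool calls, etc.) here.
    ///     println!("{chunk:?}");
    /// }
    /// ```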
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
    }
}

impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;

    fn stream_chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
    }
}

use crate::{agent::prompt_request::TypedPromptRequest, completion::TypedPrompt};
use schemars::JsonSchema;
use serde::de::DeserializeOwned;

#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<'a, T>
        = TypedPromptRequest<'a, T, M, P>
    where
        Self: 'a,
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Example
    /// ```rust,ignore
    /// use rig::prelude::*;
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(Debug, Deserialize, JsonSchema)]
    /// struct WeatherForecast {
    ///     city: String,
    ///     temperature_f: f64,
    ///     conditions: String,
    /// }
    ///
    /// let agent = client.agent("gpt-4o").build();
    ///
    /// // Type inferred from variable
    /// let forecast: WeatherForecast = agent
    ///     .prompt_typed("What's the weather in NYC?")
    ///     .await?;
    ///
    /// // Or explicit turbofish syntax
    /// let forecast = agent
    ///     .prompt_typed::<WeatherForecast>("What's the weather in NYC?")
    ///     .max_turns(3)
    ///     .await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<'_, T, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<'a, T>
        = TypedPromptRequest<'a, T, M, P>
    where
        Self: 'a,
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<'_, T, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(*self, prompt)
    }
}