// rig/agent/completion.rs

use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
use crate::{
    agent::prompt_request::streaming::StreamingPromptRequest,
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
    },
    message::ToolChoice,
    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
    tool::server::ToolServerHandle,
    vector_store::{VectorStoreError, request::VectorSearchRequest},
    wasm_compat::WasmCompatSend,
};
use std::{collections::HashMap, sync::Arc};

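/// Fallback agent name used for logging and tracing when no name is set.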
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

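/// Shared list of dynamic context sources: each entry pairs the number of
/// documents to sample with the vector store index to sample them from.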
pub type DynamicContextStore = Arc<
    Vec<(
        usize,
        Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
    )>,
>;

/// Helper function to build a completion request from agent components.
/// This is used by both `Agent::completion()` and `PromptRequest::send()`.
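///
/// The request is assembled in stages: the preamble (if any) is prepended to the
/// chat history as a system message, RAG text is resolved from the prompt or the
/// latest history message that carries it, any dynamic context indices are searched
/// concurrently for matching documents, and tool definitions are fetched from the
/// tool server (scoped to the RAG text when one is present).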
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
    model: &Arc<M>,
    prompt: Message,
    chat_history: &[Message],
    preamble: Option<&str>,
    static_context: &[Document],
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<&serde_json::Value>,
    tool_choice: Option<&ToolChoice>,
    tool_server_handle: &ToolServerHandle,
    dynamic_context: &DynamicContextStore,
    output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
    // Resolve RAG text from the prompt, falling back to the latest message in the
    // chat history that contains RAG text
    let rag_text = prompt.rag_text();
    let rag_text = rag_text.or_else(|| {
        chat_history
            .iter()
            .rev()
            .find_map(|message| message.rag_text())
    });

    // Prepend preamble as system message if present
    let chat_history: Vec<Message> = if let Some(preamble) = preamble {
        std::iter::once(Message::system(preamble.to_owned()))
            .chain(chat_history.iter().cloned())
            .collect()
    } else {
        chat_history.to_vec()
    };

    let completion_request = model
        .completion_request(prompt)
        .messages(chat_history)
        .temperature_opt(temperature)
        .max_tokens_opt(max_tokens)
        .additional_params_opt(additional_params.cloned())
        .output_schema_opt(output_schema.cloned())
        .documents(static_context.to_vec());

    let completion_request = if let Some(tool_choice) = tool_choice {
        completion_request.tool_choice(tool_choice.clone())
    } else {
        completion_request
    };

    // If the agent has RAG text, we need to fetch the dynamic context and tools
    let result = match &rag_text {
        Some(text) => {
            // Map over the vector to create async tasks
            let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
                // Clone values to move into the async block
                let text = text.clone();
                let num_sample = *num_sample;
                let index = index.clone();

                async move {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(num_sample as u64)
                        .build();

                    let docs = index
                        .top_n(req)
                        .await?
                        .into_iter()
                        .map(|(_, id, doc)| {
                            // Pretty print the document if possible for better readability
                            let text = serde_json::to_string_pretty(&doc)
                                .unwrap_or_else(|_| doc.to_string());

                            Document {
                                id,
                                text,
                                additional_props: HashMap::new(),
                            }
                        })
                        .collect::<Vec<_>>();

                    Ok::<_, VectorStoreError>(docs)
                }
            });

            // Await all vector searches concurrently
            let fetched_context: Vec<Document> = futures::future::try_join_all(search_futures)
                .await
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?
                .into_iter()
                .flatten() // Flatten the Vec<Vec<Document>> into Vec<Document>
                .collect();

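            // Fetch tool definitions scoped to the RAG text so that dynamically
            // retrieved tools relevant to the query are included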
            let tooldefs = tool_server_handle
                .get_tool_defs(Some(text.to_string()))
                .await
                .map_err(|_| {
                    CompletionError::RequestError("Failed to get tool definitions".into())
                })?;

            completion_request
                .documents(fetched_context)
                .tools(tooldefs)
        }
        None => {
            let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
                CompletionError::RequestError("Failed to get tool definitions".into())
            })?;

            completion_request.tools(tooldefs)
        }
    };

    Ok(result)
}

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// The optional type parameter `P` represents a default hook that will be used for all
/// prompt requests unless overridden via `.with_hook()` on the request.
///
/// # Example
/// ```rust,ignore
/// use rig::{client::ProviderClient, completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env()?;
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name of the agent used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and when converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
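    /// Handle to the tool server used to fetch the agent's tool definitions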
    pub tool_server_handle: ToolServerHandle,
    /// Dynamic context sources: vector store indices, each paired with the number of documents to sample
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Default maximum number of turns (depth) for recursive agent calls
    pub default_max_turns: Option<usize>,
    /// Default hook for this agent, used when no per-request hook is provided
    pub hook: Option<P>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}

impl<M, P> Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Returns the name of the agent.
    pub(crate) fn name(&self) -> &str {
        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }
}

impl<M, P> Completion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
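    /// Build (but do not send) a completion request for this agent, combining the
    /// agent's configuration, static and dynamic context, and tools with the given
    /// prompt and chat history.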
    async fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
        build_completion_request(
            &self.model,
            prompt.into(),
            &history,
            self.preamble.as_deref(),
            &self.static_context,
            self.temperature,
            self.max_tokens,
            self.additional_params.as_ref(),
            self.tool_choice.as_ref(),
            &self.tool_server_handle,
            &self.dynamic_context,
            self.output_schema.as_ref(),
        )
        .await
    }
}

// Here, we need to ensure that usage of `.prompt` on an agent uses these redefinitions on the
//  opaque `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more
//  specific `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> Chat for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
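    /// Send a prompt with prior chat history and await the final response text.
    ///
    /// A minimal usage sketch (the history contents below are illustrative):
    /// ```rust,ignore
    /// let history = vec![Message::user("Hi!"), Message::assistant("Hello, how can I help?")];
    /// let reply: String = agent.chat("Tell me a joke", history).await?;
    /// ```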
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<String, PromptError>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        PromptRequest::from_agent(self, prompt)
            .with_history(chat_history)
            .await
    }
}

impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>,
    {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}

impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;

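    /// Create a streaming prompt request for this agent; the streaming analogue of `Prompt::prompt`.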
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
    }
}

impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;

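    /// Create a streaming prompt request that carries the given chat history; the streaming analogue of `Chat::chat`.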
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, P>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
    }
}

use crate::agent::prompt_request::TypedPromptRequest;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;

#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Example
    /// ```rust,ignore
    /// use rig::prelude::*;
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(Debug, Deserialize, JsonSchema)]
    /// struct WeatherForecast {
    ///     city: String,
    ///     temperature_f: f64,
    ///     conditions: String,
    /// }
    ///
    /// let agent = client.agent("gpt-4o").build();
    ///
    /// // Type inferred from variable
    /// let forecast: WeatherForecast = agent
    ///     .prompt_typed("What's the weather in NYC?")
    ///     .await?;
    ///
    /// // Or explicit turbofish syntax
    /// let forecast = agent
    ///     .prompt_typed::<WeatherForecast>("What's the weather in NYC?")
    ///     .max_turns(3)
    ///     .await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(*self, prompt)
    }
}