//! rig/agent/completion.rs — completion-request building for [`Agent`], plus
//! the `Prompt`, `Chat`, `TypedPrompt`, and streaming trait implementations.
1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7    },
8    message::ToolChoice,
9    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10    tool::server::ToolServerHandle,
11    vector_store::{VectorStoreError, request::VectorSearchRequest},
12    wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock as TokioRwLock;
17
/// Fallback display name used in logs/tracing when an agent has no explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared, async-guarded list of dynamic-context vector indices, each paired
/// with the number of samples (`usize`) to retrieve from that index per query.
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
/// Helper function to build a completion request from agent components.
/// This is used by both `Agent::completion()` and `PromptRequest::send()`.
///
/// Assembles the prompt, chat history (with the preamble prepended as a system
/// message when present), static and dynamically-retrieved context documents,
/// sampling settings, and tool definitions into one `CompletionRequestBuilder`.
///
/// # Errors
/// Returns `CompletionError::RequestError` if a dynamic-context vector search
/// fails, or if tool definitions cannot be fetched from the tool server handle.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
    model: &Arc<M>,
    prompt: Message,
    chat_history: Vec<Message>,
    preamble: Option<&str>,
    static_context: &[Document],
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<&serde_json::Value>,
    tool_choice: Option<&ToolChoice>,
    tool_server_handle: &ToolServerHandle,
    dynamic_context: &DynamicContextStore,
    output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
    // Find the latest message in the chat history that contains RAG text.
    // The prompt itself takes precedence; otherwise scan history newest-first.
    let rag_text = prompt.rag_text();
    let rag_text = rag_text.or_else(|| {
        chat_history
            .iter()
            .rev()
            .find_map(|message| message.rag_text())
    });

    // Prepend the preamble as a system message so it is always the first
    // entry in the history the model sees.
    let chat_history = if let Some(preamble) = preamble {
        let mut with_system = Vec::with_capacity(chat_history.len() + 1);
        with_system.push(Message::system(preamble.to_owned()));
        with_system.extend(chat_history);
        with_system
    } else {
        chat_history
    };

    let completion_request = model
        .completion_request(prompt)
        .messages(chat_history)
        .temperature_opt(temperature)
        .max_tokens_opt(max_tokens)
        .additional_params_opt(additional_params.cloned())
        .output_schema_opt(output_schema.cloned())
        .documents(static_context.to_vec());

    // Tool-choice is only applied when explicitly configured on the agent.
    let completion_request = if let Some(tool_choice) = tool_choice {
        completion_request.tool_choice(tool_choice.clone())
    } else {
        completion_request
    };

    // If the agent has RAG text, we need to fetch the dynamic context and tools
    let result = match &rag_text {
        Some(text) => {
            // NOTE(review): the read guard on the dynamic-context store is held
            // across the awaits below. This is a tokio RwLock, so awaiting while
            // holding it is permitted, but writers are blocked for the duration
            // of all vector searches.
            let fetched_context = stream::iter(dynamic_context.read().await.iter())
                .then(|(num_sample, index)| async {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(*num_sample as u64)
                        .build()
                        .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                    Ok::<_, VectorStoreError>(
                        index
                            .top_n(req)
                            .await?
                            .into_iter()
                            .map(|(_, id, doc)| {
                                // Pretty print the document if possible for better readability
                                let text = serde_json::to_string_pretty(&doc)
                                    .unwrap_or_else(|_| doc.to_string());

                                Document {
                                    id,
                                    text,
                                    additional_props: HashMap::new(),
                                }
                            })
                            .collect::<Vec<_>>(),
                    )
                })
                // Flatten the per-index result sets into a single document list,
                // short-circuiting on the first vector-store error.
                .try_fold(vec![], |mut acc, docs| async {
                    acc.extend(docs);
                    Ok(acc)
                })
                .await
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Tool definitions are filtered by the RAG text when it is present.
            // NOTE(review): the underlying error from `get_tool_defs` is
            // discarded here; consider preserving it for diagnostics.
            let tooldefs = tool_server_handle
                .get_tool_defs(Some(text.to_string()))
                .await
                .map_err(|_| {
                    CompletionError::RequestError("Failed to get tool definitions".into())
                })?;

            completion_request
                .documents(fetched_context)
                .tools(tooldefs)
        }
        None => {
            // No RAG text: skip vector search and request the unfiltered tool set.
            let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
                CompletionError::RequestError("Failed to get tool definitions".into())
            })?;

            completion_request.tools(tooldefs)
        }
    };

    Ok(result)
}
137
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// The optional type parameter `P` represents a default hook that will be used for all
/// prompt requests unless overridden via `.with_hook()` on the request.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name of the agent used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch tool definitions (optionally filtered by RAG text)
    /// when building completion requests.
    pub tool_server_handle: ToolServerHandle,
    /// List of vector store, with the sample number
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Default maximum depth for recursive agent calls
    pub default_max_turns: Option<usize>,
    /// Default hook for this agent, used when no per-request hook is provided
    pub hook: Option<P>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}
197
198impl<M, P> Agent<M, P>
199where
200    M: CompletionModel,
201    P: PromptHook<M>,
202{
203    /// Returns the name of the agent.
204    pub(crate) fn name(&self) -> &str {
205        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
206    }
207}
208
209impl<M, P> Completion<M> for Agent<M, P>
210where
211    M: CompletionModel,
212    P: PromptHook<M>,
213{
214    async fn completion(
215        &self,
216        prompt: impl Into<Message> + WasmCompatSend,
217        chat_history: Vec<Message>,
218    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
219        build_completion_request(
220            &self.model,
221            prompt.into(),
222            chat_history,
223            self.preamble.as_deref(),
224            &self.static_context,
225            self.temperature,
226            self.max_tokens,
227            self.additional_params.as_ref(),
228            self.tool_choice.as_ref(),
229            &self.tool_server_handle,
230            &self.dynamic_context,
231            self.output_schema.as_ref(),
232        )
233        .await
234    }
235}
236
237// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
238//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
239//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
240//
241// References:
242//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
243
#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Creates a lazy [`PromptRequest`] builder for this agent; nothing is sent
    /// until the request is awaited.
    ///
    /// NOTE(review): the `&Agent` impl of this trait carries
    /// `#[tracing::instrument]` but this one does not — confirm whether the
    /// asymmetry is intentional.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}
257
#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Creates a lazy [`PromptRequest`] builder from a borrowed agent,
    /// recording the agent name on the tracing span.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<'_, prompt_request::Standard, M, P> {
        PromptRequest::from_agent(*self, prompt)
    }
}
272
273#[allow(refining_impl_trait)]
274impl<M, P> Chat for Agent<M, P>
275where
276    M: CompletionModel,
277    P: PromptHook<M> + 'static,
278{
279    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
280    async fn chat(
281        &self,
282        prompt: impl Into<Message> + WasmCompatSend,
283        mut chat_history: Vec<Message>,
284    ) -> Result<String, PromptError> {
285        PromptRequest::from_agent(self, prompt)
286            .with_history(&mut chat_history)
287            .await
288    }
289}
290
291impl<M, P> StreamingCompletion<M> for Agent<M, P>
292where
293    M: CompletionModel,
294    P: PromptHook<M>,
295{
296    async fn stream_completion(
297        &self,
298        prompt: impl Into<Message> + WasmCompatSend,
299        chat_history: Vec<Message>,
300    ) -> Result<CompletionRequestBuilder<M>, CompletionError> {
301        // Reuse the existing completion implementation to build the request
302        // This ensures streaming and non-streaming use the same request building logic
303        self.completion(prompt, chat_history).await
304    }
305}
306
307impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
308where
309    M: CompletionModel + 'static,
310    M::StreamingResponse: GetTokenUsage,
311    P: PromptHook<M> + 'static,
312{
313    type Hook = P;
314
315    fn stream_prompt(
316        &self,
317        prompt: impl Into<Message> + WasmCompatSend,
318    ) -> StreamingPromptRequest<M, P> {
319        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
320    }
321}
322
323impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
324where
325    M: CompletionModel + 'static,
326    M::StreamingResponse: GetTokenUsage,
327    P: PromptHook<M> + 'static,
328{
329    type Hook = P;
330
331    fn stream_chat(
332        &self,
333        prompt: impl Into<Message> + WasmCompatSend,
334        chat_history: Vec<Message>,
335    ) -> StreamingPromptRequest<M, P> {
336        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
337    }
338}
339
340use crate::agent::prompt_request::TypedPromptRequest;
341use schemars::JsonSchema;
342use serde::de::DeserializeOwned;
343
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Request type returned by [`Self::prompt_typed`], generic over the
    /// deserialization target `T` and borrowing the agent for `'a`.
    type TypedRequest<'a, T>
        = TypedPromptRequest<'a, T, prompt_request::Standard, M, P>
    where
        Self: 'a,
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Example
    /// ```rust,ignore
    /// use rig::prelude::*;
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(Debug, Deserialize, JsonSchema)]
    /// struct WeatherForecast {
    ///     city: String,
    ///     temperature_f: f64,
    ///     conditions: String,
    /// }
    ///
    /// let agent = client.agent("gpt-4o").build();
    ///
    /// // Type inferred from variable
    /// let forecast: WeatherForecast = agent
    ///     .prompt_typed("What's the weather in NYC?")
    ///     .await?;
    ///
    /// // Or explicit turbofish syntax
    /// let forecast = agent
    ///     .prompt_typed::<WeatherForecast>("What's the weather in NYC?")
    ///     .max_turns(3)
    ///     .await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<'_, T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}
398
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M> + 'static,
{
    /// Request type returned by [`Self::prompt_typed`], generic over the
    /// deserialization target `T` and borrowing the agent for `'a`.
    type TypedRequest<'a, T>
        = TypedPromptRequest<'a, T, prompt_request::Standard, M, P>
    where
        Self: 'a,
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Send a prompt and receive a typed structured response from a borrowed
    /// agent. See the owned `Agent` impl of this trait for usage examples.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<'_, T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(*self, prompt)
    }
}