//! Agent completion plumbing: shared request building plus the `Prompt`,
//! `Chat`, streaming, and typed-prompt trait implementations for [`Agent`].
1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7    },
8    message::ToolChoice,
9    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10    tool::server::ToolServerHandle,
11    vector_store::{VectorStoreError, request::VectorSearchRequest},
12    wasm_compat::WasmCompatSend,
13};
14use futures::{StreamExt, TryStreamExt, stream};
15use std::{collections::HashMap, sync::Arc};
16use tokio::sync::RwLock as TokioRwLock;
17
/// Fallback display name used in logs/tracing when an agent has no configured name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
19
/// Shared, async-guarded list of dynamic-context vector indexes.
///
/// Each entry pairs the number of samples to retrieve (`usize`) with a
/// type-erased vector store index queried at prompt time when RAG text is
/// present. Wrapped in `Arc<TokioRwLock<…>>` so agents can be cloned cheaply
/// and read concurrently.
pub type DynamicContextStore = Arc<
    TokioRwLock<
        Vec<(
            usize,
            Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
        )>,
    >,
>;
28
/// Helper function to build a completion request from agent components.
/// This is used by both `Agent::completion()` and `PromptRequest::send()`.
///
/// Steps performed:
/// 1. Locate the most recent RAG text — the prompt itself first, then the
///    chat history scanned newest-to-oldest.
/// 2. Prepend `preamble` as a system message, if present.
/// 3. Apply model settings (temperature, max tokens, additional params,
///    output schema) and attach the static context documents.
/// 4. If RAG text was found, query every dynamic vector index for extra
///    context documents and fetch RAG-filtered tool definitions; otherwise
///    fetch the unfiltered tool definitions.
///
/// # Errors
/// Returns `CompletionError::RequestError` when a vector store query fails
/// or tool definitions cannot be retrieved from the tool server.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
    model: &Arc<M>,
    prompt: Message,
    chat_history: &[Message],
    preamble: Option<&str>,
    static_context: &[Document],
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<&serde_json::Value>,
    tool_choice: Option<&ToolChoice>,
    tool_server_handle: &ToolServerHandle,
    dynamic_context: &DynamicContextStore,
    output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
    // Find the latest message that contains RAG text: prefer the prompt
    // itself, otherwise scan the chat history from newest to oldest.
    let rag_text = prompt.rag_text();
    let rag_text = rag_text.or_else(|| {
        chat_history
            .iter()
            .rev()
            .find_map(|message| message.rag_text())
    });

    // Prepend preamble as system message if present
    let chat_history: Vec<Message> = if let Some(preamble) = preamble {
        std::iter::once(Message::system(preamble.to_owned()))
            .chain(chat_history.iter().cloned())
            .collect()
    } else {
        chat_history.to_vec()
    };

    // Base request: all the agent's static configuration applied up front.
    let completion_request = model
        .completion_request(prompt)
        .messages(chat_history)
        .temperature_opt(temperature)
        .max_tokens_opt(max_tokens)
        .additional_params_opt(additional_params.cloned())
        .output_schema_opt(output_schema.cloned())
        .documents(static_context.to_vec());

    // `tool_choice` has no `_opt` builder variant, so branch manually.
    let completion_request = if let Some(tool_choice) = tool_choice {
        completion_request.tool_choice(tool_choice.clone())
    } else {
        completion_request
    };

    // If the agent has RAG text, we need to fetch the dynamic context and tools
    let result = match &rag_text {
        Some(text) => {
            // Query each (sample_count, index) pair sequentially (`then`) and
            // flatten all hits into one Vec<Document>. The read lock on the
            // dynamic context store is held for the duration of the queries.
            let fetched_context = stream::iter(dynamic_context.read().await.iter())
                .then(|(num_sample, index)| async {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(*num_sample as u64)
                        .build()
                        .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
                    Ok::<_, VectorStoreError>(
                        index
                            .top_n(req)
                            .await?
                            .into_iter()
                            .map(|(_, id, doc)| {
                                // Pretty print the document if possible for better readability
                                let text = serde_json::to_string_pretty(&doc)
                                    .unwrap_or_else(|_| doc.to_string());

                                Document {
                                    id,
                                    text,
                                    additional_props: HashMap::new(),
                                }
                            })
                            .collect::<Vec<_>>(),
                    )
                })
                // Accumulate per-index result batches; short-circuits on the
                // first VectorStoreError.
                .try_fold(vec![], |mut acc, docs| async {
                    acc.extend(docs);
                    Ok(acc)
                })
                .await
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;

            // Tool definitions filtered by the RAG text.
            let tooldefs = tool_server_handle
                .get_tool_defs(Some(text.to_string()))
                .await
                .map_err(|_| {
                    CompletionError::RequestError("Failed to get tool definitions".into())
                })?;

            completion_request
                .documents(fetched_context)
                .tools(tooldefs)
        }
        None => {
            // No RAG text: attach the full, unfiltered tool set.
            let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
                CompletionError::RequestError("Failed to get tool definitions".into())
            })?;

            completion_request.tools(tooldefs)
        }
    };

    Ok(result)
}
137
/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// The optional type parameter `P` represents a default hook that will be used for all
/// prompt requests unless overridden via `.with_hook()` on the request.
///
/// # Example
/// ```
/// use rig::{completion::Prompt, providers::openai};
///
/// let openai = openai::Client::from_env();
///
/// let comedian_agent = openai
///     .agent("gpt-4o")
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!")
///     .await
///     .expect("Failed to prompt the agent");
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name of the agent used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Handle used to fetch tool definitions (optionally filtered by RAG text) at prompt time
    pub tool_server_handle: ToolServerHandle,
    /// List of vector store, with the sample number
    pub dynamic_context: DynamicContextStore,
    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Default maximum depth for recursive agent calls
    pub default_max_turns: Option<usize>,
    /// Default hook for this agent, used when no per-request hook is provided
    pub hook: Option<P>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}
197
198impl<M, P> Agent<M, P>
199where
200    M: CompletionModel,
201    P: PromptHook<M>,
202{
203    /// Returns the name of the agent.
204    pub(crate) fn name(&self) -> &str {
205        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
206    }
207}
208
209impl<M, P> Completion<M> for Agent<M, P>
210where
211    M: CompletionModel,
212    P: PromptHook<M>,
213{
214    async fn completion<I, T>(
215        &self,
216        prompt: impl Into<Message> + WasmCompatSend,
217        chat_history: I,
218    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
219    where
220        I: IntoIterator<Item = T>,
221        T: Into<Message>,
222    {
223        let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
224        build_completion_request(
225            &self.model,
226            prompt.into(),
227            &history,
228            self.preamble.as_deref(),
229            &self.static_context,
230            self.temperature,
231            self.max_tokens,
232            self.additional_params.as_ref(),
233            self.tool_choice.as_ref(),
234            &self.tool_server_handle,
235            &self.dynamic_context,
236            self.output_schema.as_ref(),
237        )
238        .await
239    }
240}
241
242// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
243//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
244//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
245//
246// References:
247//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
248
249#[allow(refining_impl_trait)]
250impl<M, P> Prompt for Agent<M, P>
251where
252    M: CompletionModel + 'static,
253    P: PromptHook<M> + 'static,
254{
255    fn prompt(
256        &self,
257        prompt: impl Into<Message> + WasmCompatSend,
258    ) -> PromptRequest<prompt_request::Standard, M, P> {
259        PromptRequest::from_agent(self, prompt)
260    }
261}
262
// Implemented for `&Agent` as well so callers holding only a reference can
// still use the fluent `.prompt(...)` builder.
#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Creates a fluent [`PromptRequest`] builder for the referenced agent.
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(*self, prompt)
    }
}
277
278#[allow(refining_impl_trait)]
279impl<M, P> Chat for Agent<M, P>
280where
281    M: CompletionModel + 'static,
282    P: PromptHook<M> + 'static,
283{
284    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
285    async fn chat<I, T>(
286        &self,
287        prompt: impl Into<Message> + WasmCompatSend,
288        chat_history: I,
289    ) -> Result<String, PromptError>
290    where
291        I: IntoIterator<Item = T>,
292        T: Into<Message>,
293    {
294        PromptRequest::from_agent(self, prompt)
295            .with_history(chat_history)
296            .await
297    }
298}
299
impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Builds a completion request suitable for streaming.
    async fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>,
    {
        // Reuse the existing completion implementation to build the request
        // This ensures streaming and non-streaming use the same request building logic
        self.completion(prompt, chat_history).await
    }
}
319
320impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
321where
322    M: CompletionModel + 'static,
323    M::StreamingResponse: GetTokenUsage,
324    P: PromptHook<M> + 'static,
325{
326    type Hook = P;
327
328    fn stream_prompt(
329        &self,
330        prompt: impl Into<Message> + WasmCompatSend,
331    ) -> StreamingPromptRequest<M, P> {
332        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
333    }
334}
335
336impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
337where
338    M: CompletionModel + 'static,
339    M::StreamingResponse: GetTokenUsage,
340    P: PromptHook<M> + 'static,
341{
342    type Hook = P;
343
344    fn stream_chat<I, T>(
345        &self,
346        prompt: impl Into<Message> + WasmCompatSend,
347        chat_history: I,
348    ) -> StreamingPromptRequest<M, P>
349    where
350        I: IntoIterator<Item = T>,
351        T: Into<Message>,
352    {
353        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
354    }
355}
356
357use crate::agent::prompt_request::TypedPromptRequest;
358use schemars::JsonSchema;
359use serde::de::DeserializeOwned;
360
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    // Generic associated type: each call to `prompt_typed::<T>()` yields a
    // request builder parameterized over the caller's target type `T`.
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Example
    /// ```rust,ignore
    /// use rig::prelude::*;
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(Debug, Deserialize, JsonSchema)]
    /// struct WeatherForecast {
    ///     city: String,
    ///     temperature_f: f64,
    ///     conditions: String,
    /// }
    ///
    /// let agent = client.agent("gpt-4o").build();
    ///
    /// // Type inferred from variable
    /// let forecast: WeatherForecast = agent
    ///     .prompt_typed("What's the weather in NYC?")
    ///     .await?;
    ///
    /// // Or explicit turbofish syntax
    /// let forecast = agent
    ///     .prompt_typed::<WeatherForecast>("What's the weather in NYC?")
    ///     .max_turns(3)
    ///     .await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}
414
// Implemented for `&Agent` as well so callers holding only a reference can
// still use `.prompt_typed(...)` fluently.
#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response from the
    /// referenced agent.
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(*self, prompt)
    }
}