// rig_core/agent/completion.rs

use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
use crate::{
    agent::prompt_request::streaming::StreamingPromptRequest,
    completion::{
        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
        GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
    },
    message::ToolChoice,
    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
    tool::server::ToolServerHandle,
    vector_store::{VectorStoreError, request::VectorSearchRequest},
    wasm_compat::WasmCompatSend,
};
use std::{collections::HashMap, sync::Arc};

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

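/// Shared list of vector store indices, each paired with the number of samples
/// to retrieve from that index when resolving dynamic context for a prompt.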
pub type DynamicContextStore = Arc<
    Vec<(
        usize,
        Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
    )>,
>;

/// Helper function to build a completion request from agent components.
/// This is used by both `Agent::completion()` and `PromptRequest::send()`.
/// When the prompt (or the most recent history message) carries RAG text, the
/// request also pulls in dynamic context documents and matching tool definitions.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
    model: &Arc<M>,
    prompt: Message,
    chat_history: &[Message],
    preamble: Option<&str>,
    static_context: &[Document],
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    additional_params: Option<&serde_json::Value>,
    tool_choice: Option<&ToolChoice>,
    tool_server_handle: &ToolServerHandle,
    dynamic_context: &DynamicContextStore,
    output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
    // Take the RAG text from the prompt, falling back to the most recent
    // message in the chat history that contains RAG text
    let rag_text = prompt.rag_text();
    let rag_text = rag_text.or_else(|| {
        chat_history
            .iter()
            .rev()
            .find_map(|message| message.rag_text())
    });

    // Prepend preamble as system message if present
    let chat_history: Vec<Message> = if let Some(preamble) = preamble {
        std::iter::once(Message::system(preamble.to_owned()))
            .chain(chat_history.iter().cloned())
            .collect()
    } else {
        chat_history.to_vec()
    };

    let completion_request = model
        .completion_request(prompt)
        .messages(chat_history)
        .temperature_opt(temperature)
        .max_tokens_opt(max_tokens)
        .additional_params_opt(additional_params.cloned())
        .output_schema_opt(output_schema.cloned())
        .documents(static_context.to_vec());

    let completion_request = if let Some(tool_choice) = tool_choice {
        completion_request.tool_choice(tool_choice.clone())
    } else {
        completion_request
    };

    // If the agent has RAG text, we need to fetch the dynamic context and tools
    let result = match &rag_text {
        Some(text) => {
            // Map over the vector to create async tasks
            let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
                // Clone values to move into the async block
                let text = text.clone();
                let num_sample = *num_sample;
                let index = index.clone();

                async move {
                    let req = VectorSearchRequest::builder()
                        .query(text)
                        .samples(num_sample as u64)
                        .build();

                    let docs = index
                        .top_n(req)
                        .await?
                        .into_iter()
                        .map(|(_, id, doc)| {
                            // Pretty print the document if possible for better readability
                            let text = serde_json::to_string_pretty(&doc)
                                .unwrap_or_else(|_| doc.to_string());

                            Document {
                                id,
                                text,
                                additional_props: HashMap::new(),
                            }
                        })
                        .collect::<Vec<_>>();

                    Ok::<_, VectorStoreError>(docs)
                }
            });

            // Await all vector searches concurrently
            let fetched_context: Vec<Document> = futures::future::try_join_all(search_futures)
                .await
                .map_err(|e| CompletionError::RequestError(Box::new(e)))?
                .into_iter()
                .flatten() // Flatten the Vec<Vec<Document>> into Vec<Document>
                .collect();

            let tooldefs = tool_server_handle
                .get_tool_defs(Some(text.to_string()))
                .await
                .map_err(|_| {
                    CompletionError::RequestError("Failed to get tool definitions".into())
                })?;

            completion_request
                .documents(fetched_context)
                .tools(tooldefs)
        }
        None => {
            let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
                CompletionError::RequestError("Failed to get tool definitions".into())
            })?;

            completion_request.tools(tooldefs)
        }
    };

    Ok(result)
}

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
/// All context documents and tools are always provided to the agent when prompted.
///
/// The optional type parameter `P` represents a default hook that will be used for all
/// prompt requests unless overridden via `.with_hook()` on the request.
///
/// # Example
/// ```no_run
/// use rig_core::{
///     client::{CompletionClient, ProviderClient},
///     completion::Prompt,
///     providers::openai,
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let openai = openai::Client::from_env()?;
///
/// let comedian_agent = openai
///     .agent(openai::GPT_5_2)
///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
///     .temperature(0.9)
///     .build();
///
/// let response = comedian_agent.prompt("Entertain me!").await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone)]
#[non_exhaustive]
pub struct Agent<M, P = ()>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Name of the agent, used for logging and debugging
    pub name: Option<String>,
    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
    pub description: Option<String>,
    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
    pub model: Arc<M>,
    /// System prompt
    pub preamble: Option<String>,
    /// Context documents always available to the agent
    pub static_context: Vec<Document>,
    /// Temperature of the model
    pub temperature: Option<f64>,
    /// Maximum number of tokens for the completion
    pub max_tokens: Option<u64>,
    /// Additional parameters to be passed to the model
    pub additional_params: Option<serde_json::Value>,
    /// Handle to the tool server that resolves this agent's tool definitions
    pub tool_server_handle: ToolServerHandle,
    /// List of vector store indices, each paired with the number of samples to retrieve
    pub dynamic_context: DynamicContextStore,
    /// Whether (and how) the underlying LLM should be forced to use a tool before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Default maximum number of turns for recursive agent calls
    pub default_max_turns: Option<usize>,
    /// Default hook for this agent, used when no per-request hook is provided
    pub hook: Option<P>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
    /// Optional conversation memory backend that loads/saves history per conversation id.
    pub memory: Option<Arc<dyn crate::memory::ConversationMemory>>,
    /// Optional default conversation id used when none is set per-request.
    pub default_conversation_id: Option<String>,
}

impl<M, P> Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Returns the name of the agent, falling back to `UNKNOWN_AGENT_NAME` if none is set.
    pub(crate) fn name(&self) -> &str {
        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }
}

impl<M, P> Completion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
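    /// Builds, but does not send, a completion request carrying this agent's
    /// preamble, static context, tools, and model parameters.
    ///
    /// A minimal usage sketch (hedged: assumes the returned
    /// `CompletionRequestBuilder`'s `send` method as the dispatch step and an
    /// already-built `agent`):
    ///
    /// ```rust,ignore
    /// // Build the request first, then dispatch it separately.
    /// let request = agent
    ///     .completion("Hello", Vec::<Message>::new())
    ///     .await?;
    /// let response = request.send().await?;
    /// ```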
    async fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
        build_completion_request(
            &self.model,
            prompt.into(),
            &history,
            self.preamble.as_deref(),
            &self.static_context,
            self.temperature,
            self.max_tokens,
            self.additional_params.as_ref(),
            self.tool_choice.as_ref(),
            &self.tool_server_handle,
            &self.dynamic_context,
            self.output_schema.as_ref(),
        )
        .await
    }
}

// These impls refine the opaque return type of the `Prompt` trait so that, at
// the call-site, `.prompt` on an `Agent` yields the more specific
// `PromptRequest` type, keeping the builder's usage fluent.
//
// References:
//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)

#[allow(refining_impl_trait)]
impl<M, P> Prompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
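    /// Creates a fluent [`PromptRequest`] builder for this agent.
    ///
    /// A minimal sketch (hedged: `history` is an assumed `Vec<Message>`; the
    /// builder method shown is the one defined in `prompt_request`):
    ///
    /// ```rust,ignore
    /// // Attach prior history, then await the request directly.
    /// let answer = agent
    ///     .prompt("What did we talk about earlier?")
    ///     .with_history(history.clone())
    ///     .await?;
    /// ```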
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> Prompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> PromptRequest<prompt_request::Standard, M, P> {
        PromptRequest::from_agent(*self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> Chat for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
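    /// Sends a prompt with the given chat history and, when the response
    /// carries updated messages, appends them to `chat_history` in place.
    ///
    /// A minimal sketch (hedged: agent construction is assumed):
    ///
    /// ```rust,ignore
    /// let mut history = Vec::new();
    /// let first = agent.chat("Hello!", &mut history).await?;
    /// // `history` now contains the exchange, so follow-ups keep context.
    /// let second = agent.chat("What did I just say?", &mut history).await?;
    /// ```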
    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
    async fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: &mut Vec<Message>,
    ) -> Result<String, PromptError> {
        let response = PromptRequest::from_agent(self, prompt)
            .with_history(chat_history.clone())
            .extended_details()
            .await?;

        if let Some(messages) = response.messages {
            chat_history.extend(messages);
        }

        Ok(response.output)
    }
}

impl<M, P> StreamingCompletion<M> for Agent<M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    async fn stream_completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>,
    {
        // Reuse the existing completion implementation to build the request.
        // This ensures streaming and non-streaming use the same request building logic.
        self.completion(prompt, chat_history).await
    }
}

impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;

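    /// Starts a streaming prompt request against this agent.
    ///
    /// A hedged consumption sketch, assuming the returned request resolves via
    /// `.await` to a `futures::Stream` of chunks (the exact item type lives in
    /// `prompt_request::streaming`, not in this file):
    ///
    /// ```rust,ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = agent.stream_prompt("Tell me a story").await;
    /// while let Some(chunk) = stream.next().await {
    ///     // Handle each streamed chunk as it arrives.
    ///     println!("{chunk:?}");
    /// }
    /// ```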
    fn stream_prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> StreamingPromptRequest<M, P> {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
    }
}

impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
where
    M: CompletionModel + 'static,
    M::StreamingResponse: GetTokenUsage,
    P: PromptHook<M> + 'static,
{
    type Hook = P;

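    /// Starts a streaming prompt request seeded with the given chat history;
    /// otherwise this behaves like [`StreamingPrompt::stream_prompt`] above.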
    fn stream_chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> StreamingPromptRequest<M, P>
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
    }
}

use crate::agent::prompt_request::TypedPromptRequest;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;

#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Example
    /// ```rust,ignore
    /// use rig_core::prelude::*;
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(Debug, Deserialize, JsonSchema)]
    /// struct WeatherForecast {
    ///     city: String,
    ///     temperature_f: f64,
    ///     conditions: String,
    /// }
    ///
    /// let agent = client.agent("gpt-4o").build();
    ///
    /// // Type inferred from variable
    /// let forecast: WeatherForecast = agent
    ///     .prompt_typed("What's the weather in NYC?")
    ///     .await?;
    ///
    /// // Or explicit turbofish syntax
    /// let forecast = agent
    ///     .prompt_typed::<WeatherForecast>("What's the weather in NYC?")
    ///     .max_turns(3)
    ///     .await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(self, prompt)
    }
}

#[allow(refining_impl_trait)]
impl<M, P> TypedPrompt for &Agent<M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type TypedRequest<T>
        = TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
    where
        T: JsonSchema + DeserializeOwned + WasmCompatSend,
    {
        TypedPromptRequest::from_agent(*self, prompt)
    }
}