Skip to main content

rig_core/agent/
completion.rs

1use super::prompt_request::{self, PromptRequest, hooks::PromptHook};
2use crate::{
3    agent::prompt_request::streaming::StreamingPromptRequest,
4    completion::{
5        Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
6        GetTokenUsage, Message, Prompt, PromptError, TypedPrompt,
7    },
8    message::ToolChoice,
9    streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
10    tool::server::ToolServerHandle,
11    vector_store::{VectorStoreError, request::VectorSearchRequest},
12    wasm_compat::WasmCompatSend,
13};
14use std::{
15    collections::{BTreeSet, HashMap},
16    sync::Arc,
17};
18
/// Fallback name returned by `Agent::name` when the agent has no explicit name.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
20
21pub type DynamicContextStore = Arc<
22    Vec<(
23        usize,
24        Arc<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
25    )>,
26>;
27
28/// A prepared completion request plus the executable Rig tool names advertised
29/// to the provider for this turn.
30pub(crate) struct PreparedCompletionRequest<M: CompletionModel> {
31    pub(crate) builder: CompletionRequestBuilder<M>,
32    pub(crate) executable_tool_names: BTreeSet<String>,
33    pub(crate) allowed_tool_names: BTreeSet<String>,
34}
35
36pub(crate) fn allowed_tool_names_for_choice(
37    executable_tool_names: &BTreeSet<String>,
38    tool_choice: Option<&ToolChoice>,
39) -> Result<BTreeSet<String>, CompletionError> {
40    let allowed = match tool_choice {
41        None | Some(ToolChoice::Auto | ToolChoice::Required) => executable_tool_names.clone(),
42        Some(ToolChoice::None) => BTreeSet::new(),
43        Some(ToolChoice::Specific { function_names }) => {
44            if function_names.is_empty() {
45                return Err(CompletionError::RequestError(
46                    "ToolChoice::Specific requires at least one function name".into(),
47                ));
48            }
49
50            let requested = function_names.iter().cloned().collect::<BTreeSet<String>>();
51            let missing = requested
52                .difference(executable_tool_names)
53                .cloned()
54                .collect::<Vec<_>>();
55
56            if !missing.is_empty() {
57                return Err(CompletionError::RequestError(
58                    format!(
59                        "ToolChoice::Specific requested unknown tool names: {missing:?}. Available tools: {:?}",
60                        executable_tool_names.iter().collect::<Vec<_>>()
61                    )
62                    .into(),
63                ));
64            }
65
66            requested
67        }
68    };
69
70    Ok(allowed)
71}
72
73/// Helper function to build a completion request from agent components.
74/// This is used by `Agent::completion()` to preserve the public completion API.
75#[allow(clippy::too_many_arguments)]
76pub(crate) async fn build_completion_request<M: CompletionModel>(
77    model: &Arc<M>,
78    prompt: Message,
79    chat_history: &[Message],
80    preamble: Option<&str>,
81    static_context: &[Document],
82    temperature: Option<f64>,
83    max_tokens: Option<u64>,
84    additional_params: Option<&serde_json::Value>,
85    tool_choice: Option<&ToolChoice>,
86    tool_server_handle: &ToolServerHandle,
87    dynamic_context: &DynamicContextStore,
88    output_schema: Option<&schemars::Schema>,
89) -> Result<CompletionRequestBuilder<M>, CompletionError> {
90    Ok(build_prepared_completion_request(
91        model,
92        prompt,
93        chat_history,
94        preamble,
95        static_context,
96        temperature,
97        max_tokens,
98        additional_params,
99        tool_choice,
100        tool_server_handle,
101        dynamic_context,
102        output_schema,
103    )
104    .await?
105    .builder)
106}
107
108/// Helper function to build a completion request from agent components while
109/// preserving the executable Rig tool names sent to the provider.
110#[allow(clippy::too_many_arguments)]
111pub(crate) async fn build_prepared_completion_request<M: CompletionModel>(
112    model: &Arc<M>,
113    prompt: Message,
114    chat_history: &[Message],
115    preamble: Option<&str>,
116    static_context: &[Document],
117    temperature: Option<f64>,
118    max_tokens: Option<u64>,
119    additional_params: Option<&serde_json::Value>,
120    tool_choice: Option<&ToolChoice>,
121    tool_server_handle: &ToolServerHandle,
122    dynamic_context: &DynamicContextStore,
123    output_schema: Option<&schemars::Schema>,
124) -> Result<PreparedCompletionRequest<M>, CompletionError> {
125    // Find the latest message in the chat history that contains RAG text
126    let rag_text = prompt.rag_text();
127    let rag_text = rag_text.or_else(|| {
128        chat_history
129            .iter()
130            .rev()
131            .find_map(|message| message.rag_text())
132    });
133
134    // Prepend preamble as system message if present
135    let chat_history: Vec<Message> = if let Some(preamble) = preamble {
136        std::iter::once(Message::system(preamble.to_owned()))
137            .chain(chat_history.iter().cloned())
138            .collect()
139    } else {
140        chat_history.to_vec()
141    };
142
143    let completion_request = model
144        .completion_request(prompt)
145        .messages(chat_history)
146        .temperature_opt(temperature)
147        .max_tokens_opt(max_tokens)
148        .additional_params_opt(additional_params.cloned())
149        .output_schema_opt(output_schema.cloned())
150        .documents(static_context.to_vec());
151
152    let completion_request = if let Some(tool_choice) = tool_choice {
153        completion_request.tool_choice(tool_choice.clone())
154    } else {
155        completion_request
156    };
157
158    // If the agent has RAG text, we need to fetch the dynamic context and tools
159    let (builder, executable_tool_names) = match &rag_text {
160        Some(text) => {
161            // Map over the vector to create async tasks
162            let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
163                // Clone values to move into the async block
164                let text = text.clone();
165                let num_sample = *num_sample;
166                let index = index.clone();
167
168                async move {
169                    let req = VectorSearchRequest::builder()
170                        .query(text)
171                        .samples(num_sample as u64)
172                        .build();
173
174                    let docs = index
175                        .top_n(req)
176                        .await?
177                        .into_iter()
178                        .map(|(_, id, doc)| {
179                            // Pretty print the document if possible for better readability
180                            let text = serde_json::to_string_pretty(&doc)
181                                .unwrap_or_else(|_| doc.to_string());
182
183                            Document {
184                                id,
185                                text,
186                                additional_props: HashMap::new(),
187                            }
188                        })
189                        .collect::<Vec<_>>();
190
191                    Ok::<_, VectorStoreError>(docs)
192                }
193            });
194
195            // Await all vector searches concurrently
196            let fetched_context: Vec<Document> = futures::future::try_join_all(search_futures)
197                .await
198                .map_err(|e| CompletionError::RequestError(Box::new(e)))?
199                .into_iter()
200                .flatten() // Flatten the Vec<Vec<Document>> into Vec<Document>
201                .collect();
202
203            let tooldefs = tool_server_handle
204                .get_tool_defs(Some(text.to_string()))
205                .await
206                .map_err(|_| {
207                    CompletionError::RequestError("Failed to get tool definitions".into())
208                })?;
209            let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect();
210
211            (
212                completion_request
213                    .documents(fetched_context)
214                    .tools(tooldefs),
215                executable_tool_names,
216            )
217        }
218        None => {
219            let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
220                CompletionError::RequestError("Failed to get tool definitions".into())
221            })?;
222            let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect();
223
224            (completion_request.tools(tooldefs), executable_tool_names)
225        }
226    };
227    let allowed_tool_names = allowed_tool_names_for_choice(&executable_tool_names, tool_choice)?;
228
229    Ok(PreparedCompletionRequest {
230        builder,
231        executable_tool_names,
232        allowed_tool_names,
233    })
234}
235
236/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
237/// (i.e.: system prompt) and a static set of context documents and tools.
238/// All context documents and tools are always provided to the agent when prompted.
239///
240/// The optional type parameter `P` represents a default hook that will be used for all
241/// prompt requests unless overridden via `.with_hook()` on the request.
242///
243/// # Example
244/// ```no_run
245/// use rig_core::{
246///     client::{CompletionClient, ProviderClient},
247///     completion::Prompt,
248///     providers::openai,
249/// };
250///
251/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
252/// let openai = openai::Client::from_env()?;
253///
254/// let comedian_agent = openai
255///     .agent(openai::GPT_5_2)
256///     .preamble("You are a comedian here to entertain the user using humour and jokes.")
257///     .temperature(0.9)
258///     .build();
259///
260/// let response = comedian_agent.prompt("Entertain me!").await?;
261/// # Ok(())
262/// # }
263/// ```
264#[derive(Clone)]
265#[non_exhaustive]
266pub struct Agent<M, P = ()>
267where
268    M: CompletionModel,
269    P: PromptHook<M>,
270{
271    /// Name of the agent used for logging and debugging
272    pub name: Option<String>,
273    /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats.
274    pub description: Option<String>,
275    /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r)
276    pub model: Arc<M>,
277    /// System prompt
278    pub preamble: Option<String>,
279    /// Context documents always available to the agent
280    pub static_context: Vec<Document>,
281    /// Temperature of the model
282    pub temperature: Option<f64>,
283    /// Maximum number of tokens for the completion
284    pub max_tokens: Option<u64>,
285    /// Additional parameters to be passed to the model
286    pub additional_params: Option<serde_json::Value>,
287    pub tool_server_handle: ToolServerHandle,
288    /// List of vector store, with the sample number
289    pub dynamic_context: DynamicContextStore,
290    /// Whether or not the underlying LLM should be forced to use a tool before providing a response.
291    pub tool_choice: Option<ToolChoice>,
292    /// Default maximum depth for recursive agent calls
293    pub default_max_turns: Option<usize>,
294    /// Default hook for this agent, used when no per-request hook is provided
295    pub hook: Option<P>,
296    /// Optional JSON Schema for structured output. When set, providers that support
297    /// native structured outputs will constrain the model's response to match this schema.
298    pub output_schema: Option<schemars::Schema>,
299    /// Optional conversation memory backend that loads/saves history per conversation id.
300    pub memory: Option<Arc<dyn crate::memory::ConversationMemory>>,
301    /// Optional default conversation id used when none is set per-request.
302    pub default_conversation_id: Option<String>,
303}
304
305impl<M, P> Agent<M, P>
306where
307    M: CompletionModel,
308    P: PromptHook<M>,
309{
310    /// Returns the name of the agent.
311    pub(crate) fn name(&self) -> &str {
312        self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
313    }
314}
315
316impl<M, P> Completion<M> for Agent<M, P>
317where
318    M: CompletionModel,
319    P: PromptHook<M>,
320{
321    async fn completion<I, T>(
322        &self,
323        prompt: impl Into<Message> + WasmCompatSend,
324        chat_history: I,
325    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
326    where
327        I: IntoIterator<Item = T>,
328        T: Into<Message>,
329    {
330        let history: Vec<Message> = chat_history.into_iter().map(Into::into).collect();
331        build_completion_request(
332            &self.model,
333            prompt.into(),
334            &history,
335            self.preamble.as_deref(),
336            &self.static_context,
337            self.temperature,
338            self.max_tokens,
339            self.additional_params.as_ref(),
340            self.tool_choice.as_ref(),
341            &self.tool_server_handle,
342            &self.dynamic_context,
343            self.output_schema.as_ref(),
344        )
345        .await
346    }
347}
348
349// Here, we need to ensure that usage of `.prompt` on agent uses these redefinitions on the opaque
350//  `Prompt` trait so that when `.prompt` is used at the call-site, it'll use the more specific
351//  `PromptRequest` implementation for `Agent`, making the builder's usage fluent.
352//
353// References:
354//  - https://github.com/rust-lang/rust/issues/121718 (refining_impl_trait)
355
356#[allow(refining_impl_trait)]
357impl<M, P> Prompt for Agent<M, P>
358where
359    M: CompletionModel + 'static,
360    P: PromptHook<M> + 'static,
361{
362    fn prompt(
363        &self,
364        prompt: impl Into<Message> + WasmCompatSend,
365    ) -> PromptRequest<prompt_request::Standard, M, P> {
366        PromptRequest::from_agent(self, prompt)
367    }
368}
369
370#[allow(refining_impl_trait)]
371impl<M, P> Prompt for &Agent<M, P>
372where
373    M: CompletionModel + 'static,
374    P: PromptHook<M> + 'static,
375{
376    #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
377    fn prompt(
378        &self,
379        prompt: impl Into<Message> + WasmCompatSend,
380    ) -> PromptRequest<prompt_request::Standard, M, P> {
381        PromptRequest::from_agent(*self, prompt)
382    }
383}
384
385#[allow(refining_impl_trait)]
386impl<M, P> Chat for Agent<M, P>
387where
388    M: CompletionModel + 'static,
389    P: PromptHook<M> + 'static,
390{
391    #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
392    async fn chat(
393        &self,
394        prompt: impl Into<Message> + WasmCompatSend,
395        chat_history: &mut Vec<Message>,
396    ) -> Result<String, PromptError> {
397        let response = PromptRequest::from_agent(self, prompt)
398            .with_history(chat_history.clone())
399            .extended_details()
400            .await?;
401
402        if let Some(messages) = response.messages {
403            chat_history.extend(messages);
404        }
405
406        Ok(response.output)
407    }
408}
409
410impl<M, P> StreamingCompletion<M> for Agent<M, P>
411where
412    M: CompletionModel,
413    P: PromptHook<M>,
414{
415    async fn stream_completion<I, T>(
416        &self,
417        prompt: impl Into<Message> + WasmCompatSend,
418        chat_history: I,
419    ) -> Result<CompletionRequestBuilder<M>, CompletionError>
420    where
421        I: IntoIterator<Item = T> + WasmCompatSend,
422        T: Into<Message>,
423    {
424        // Reuse the existing completion implementation to build the request
425        // This ensures streaming and non-streaming use the same request building logic
426        self.completion(prompt, chat_history).await
427    }
428}
429
430impl<M, P> StreamingPrompt<M, M::StreamingResponse> for Agent<M, P>
431where
432    M: CompletionModel + 'static,
433    M::StreamingResponse: GetTokenUsage,
434    P: PromptHook<M> + 'static,
435{
436    type Hook = P;
437
438    fn stream_prompt(
439        &self,
440        prompt: impl Into<Message> + WasmCompatSend,
441    ) -> StreamingPromptRequest<M, P> {
442        StreamingPromptRequest::<M, P>::from_agent(self, prompt)
443    }
444}
445
446impl<M, P> StreamingChat<M, M::StreamingResponse> for Agent<M, P>
447where
448    M: CompletionModel + 'static,
449    M::StreamingResponse: GetTokenUsage,
450    P: PromptHook<M> + 'static,
451{
452    type Hook = P;
453
454    fn stream_chat<I, T>(
455        &self,
456        prompt: impl Into<Message> + WasmCompatSend,
457        chat_history: I,
458    ) -> StreamingPromptRequest<M, P>
459    where
460        I: IntoIterator<Item = T>,
461        T: Into<Message>,
462    {
463        StreamingPromptRequest::<M, P>::from_agent(self, prompt).with_history(chat_history)
464    }
465}
466
467use crate::agent::prompt_request::TypedPromptRequest;
468use schemars::JsonSchema;
469use serde::de::DeserializeOwned;
470
471#[allow(refining_impl_trait)]
472impl<M, P> TypedPrompt for Agent<M, P>
473where
474    M: CompletionModel + 'static,
475    P: PromptHook<M> + 'static,
476{
477    type TypedRequest<T>
478        = TypedPromptRequest<T, prompt_request::Standard, M, P>
479    where
480        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
481
482    /// Send a prompt and receive a typed structured response.
483    ///
484    /// The JSON schema for `T` is automatically generated and sent to the provider.
485    /// Providers that support native structured outputs will constrain the model's
486    /// response to match this schema.
487    ///
488    /// # Example
489    /// ```rust,ignore
490    /// use rig_core::prelude::*;
491    /// use schemars::JsonSchema;
492    /// use serde::Deserialize;
493    ///
494    /// #[derive(Debug, Deserialize, JsonSchema)]
495    /// struct WeatherForecast {
496    ///     city: String,
497    ///     temperature_f: f64,
498    ///     conditions: String,
499    /// }
500    ///
501    /// let agent = client.agent("gpt-4o").build();
502    ///
503    /// // Type inferred from variable
504    /// let forecast: WeatherForecast = agent
505    ///     .prompt_typed("What's the weather in NYC?")
506    ///     .await?;
507    ///
508    /// // Or explicit turbofish syntax
509    /// let forecast = agent
510    ///     .prompt_typed::<WeatherForecast>("What's the weather in NYC?")
511    ///     .max_turns(3)
512    ///     .await?;
513    /// ```
514    fn prompt_typed<T>(
515        &self,
516        prompt: impl Into<Message> + WasmCompatSend,
517    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
518    where
519        T: JsonSchema + DeserializeOwned + WasmCompatSend,
520    {
521        TypedPromptRequest::from_agent(self, prompt)
522    }
523}
524
525#[allow(refining_impl_trait)]
526impl<M, P> TypedPrompt for &Agent<M, P>
527where
528    M: CompletionModel + 'static,
529    P: PromptHook<M> + 'static,
530{
531    type TypedRequest<T>
532        = TypedPromptRequest<T, prompt_request::Standard, M, P>
533    where
534        T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
535
536    fn prompt_typed<T>(
537        &self,
538        prompt: impl Into<Message> + WasmCompatSend,
539    ) -> TypedPromptRequest<T, prompt_request::Standard, M, P>
540    where
541        T: JsonSchema + DeserializeOwned + WasmCompatSend,
542    {
543        TypedPromptRequest::from_agent(*self, prompt)
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned, ordered set of tool names from string literals.
    fn tool_names(names: &[&str]) -> BTreeSet<String> {
        names.iter().copied().map(String::from).collect()
    }

    /// Renders the message of a `RequestError`; any other variant fails the test.
    fn request_error_message(err: CompletionError) -> String {
        match err {
            CompletionError::RequestError(source) => source.to_string(),
            other => panic!("expected CompletionError::RequestError, got {other:?}"),
        }
    }

    #[test]
    fn allowed_tool_names_defaults_to_all_executable_tools() {
        let executable = tool_names(&["add", "subtract"]);

        let allowed = allowed_tool_names_for_choice(&executable, None).unwrap();

        assert_eq!(allowed, executable);
    }

    #[test]
    fn allowed_tool_names_auto_and_required_allow_all_executable_tools() {
        let executable = tool_names(&["add", "subtract"]);

        let with_auto =
            allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Auto)).unwrap();
        assert_eq!(with_auto, executable);

        let with_required =
            allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Required)).unwrap();
        assert_eq!(with_required, executable);
    }

    #[test]
    fn allowed_tool_names_none_allows_no_tools() {
        let executable = tool_names(&["add", "subtract"]);

        let allowed =
            allowed_tool_names_for_choice(&executable, Some(&ToolChoice::None)).unwrap();

        assert!(allowed.is_empty());
    }

    #[test]
    fn allowed_tool_names_specific_allows_requested_executable_tools() {
        let executable = tool_names(&["add", "subtract"]);
        let choice = ToolChoice::Specific {
            function_names: vec!["add".to_string()],
        };

        let allowed = allowed_tool_names_for_choice(&executable, Some(&choice)).unwrap();

        assert_eq!(allowed, tool_names(&["add"]));
    }

    #[test]
    fn allowed_tool_names_specific_rejects_missing_tools() {
        let executable = tool_names(&["add"]);
        let choice = ToolChoice::Specific {
            function_names: vec!["missing".to_string()],
        };

        let err = allowed_tool_names_for_choice(&executable, Some(&choice))
            .expect_err("missing specific tool should fail before provider request");

        // The message must name both the unknown tool and the available one.
        let message = request_error_message(err);
        assert!(message.contains("missing"));
        assert!(message.contains("add"));
    }

    #[test]
    fn allowed_tool_names_specific_rejects_empty_names() {
        let executable = tool_names(&["add"]);
        let choice = ToolChoice::Specific {
            function_names: vec![],
        };

        let err = allowed_tool_names_for_choice(&executable, Some(&choice))
            .expect_err("empty specific tool choice should fail before provider request");

        let message = request_error_message(err);
        assert!(message.contains("requires at least one function name"));
    }
}