Skip to main content

rig/completion/
request.rs

1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
//! The [Completion] trait defines a lower level interface that is useful when the user wants
17//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
37//! let request = gpt_4.completion_request("Who are you?")
38//!     .preamble("\
39//!         You are Marvin, an extremely smart but depressed robot who is \
40//!         nonetheless helpful towards humanity.\
41//!     ")
42//!     .temperature(0.5)
43//!     .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//!     .await
48//!     .expect("Failed to get completion response");
49//!
50//! // Handle the completion response
//! match response.choice {
52//!     ModelChoice::Message(message) => {
53//!         // Handle the completion response as a message
54//!         println!("Received message: {}", message);
55//!     }
56//!     ModelChoice::ToolCall(tool_name, tool_params) => {
57//!         // Handle the completion response as a tool call
58//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
59//!     }
60//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76    json_utils,
77    message::{Message, UserContent},
78    tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
87// Errors
88#[derive(Debug, Error)]
89pub enum CompletionError {
90    /// Http error (e.g.: connection error, timeout, etc.)
91    #[error("HttpError: {0}")]
92    HttpError(#[from] http_client::Error),
93
94    /// Json error (e.g.: serialization, deserialization)
95    #[error("JsonError: {0}")]
96    JsonError(#[from] serde_json::Error),
97
98    /// Url error (e.g.: invalid URL)
99    #[error("UrlError: {0}")]
100    UrlError(#[from] url::ParseError),
101
102    #[cfg(not(target_family = "wasm"))]
103    /// Error building the completion request
104    #[error("RequestError: {0}")]
105    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
106
107    #[cfg(target_family = "wasm")]
108    /// Error building the completion request
109    #[error("RequestError: {0}")]
110    RequestError(#[from] Box<dyn std::error::Error + 'static>),
111
112    /// Error parsing the completion response
113    #[error("ResponseError: {0}")]
114    ResponseError(String),
115
116    /// Error returned by the completion model provider
117    #[error("ProviderError: {0}")]
118    ProviderError(String),
119}
120
121/// Prompt errors
122#[derive(Debug, Error)]
123pub enum PromptError {
124    /// Something went wrong with the completion
125    #[error("CompletionError: {0}")]
126    CompletionError(#[from] CompletionError),
127
128    /// There was an error while using a tool
129    #[error("ToolCallError: {0}")]
130    ToolError(#[from] ToolSetError),
131
132    /// There was an issue while executing a tool on a tool server
133    #[error("ToolServerError: {0}")]
134    ToolServerError(#[from] Box<ToolServerError>),
135
136    /// The LLM tried to call too many tools during a multi-turn conversation.
137    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
138    /// or increase the amount of turns given in `.multi_turn()`.
139    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
140    MaxTurnsError {
141        max_turns: usize,
142        chat_history: Box<Vec<Message>>,
143        prompt: Box<Message>,
144    },
145
146    /// A prompting loop was cancelled.
147    #[error("PromptCancelled: {reason}")]
148    PromptCancelled {
149        chat_history: Vec<Message>,
150        reason: String,
151    },
152}
153
154impl PromptError {
155    pub(crate) fn prompt_cancelled(
156        chat_history: impl IntoIterator<Item = Message>,
157        reason: impl Into<String>,
158    ) -> Self {
159        Self::PromptCancelled {
160            chat_history: chat_history.into_iter().collect(),
161            reason: reason.into(),
162        }
163    }
164}
165
166/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
167#[derive(Debug, Error)]
168pub enum StructuredOutputError {
169    /// An error occurred during the prompt execution.
170    #[error("PromptError: {0}")]
171    PromptError(#[from] Box<PromptError>),
172
173    /// Failed to deserialize the model's response into the target type.
174    #[error("DeserializationError: {0}")]
175    DeserializationError(#[from] serde_json::Error),
176
177    /// The model returned an empty response.
178    #[error("EmptyResponse: model returned no content")]
179    EmptyResponse,
180}
181
182#[derive(Clone, Debug, Deserialize, Serialize)]
183pub struct Document {
184    pub id: String,
185    pub text: String,
186    #[serde(flatten)]
187    pub additional_props: HashMap<String, String>,
188}
189
190impl std::fmt::Display for Document {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        write!(
193            f,
194            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
195            self.id,
196            if self.additional_props.is_empty() {
197                self.text.clone()
198            } else {
199                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
200                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
201                let metadata = sorted_props
202                    .iter()
203                    .map(|(k, v)| format!("{k}: {v:?}"))
204                    .collect::<Vec<_>>()
205                    .join(" ");
206                format!("<metadata {} />\n{}", metadata, self.text)
207            }
208        )
209    }
210}
211
212#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
213pub struct ToolDefinition {
214    pub name: String,
215    pub description: String,
216    pub parameters: serde_json::Value,
217}
218
219/// Provider-native tool definition.
220///
221/// Stored under `additional_params.tools` and forwarded by providers that support
222/// provider-managed tools.
223#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
224pub struct ProviderToolDefinition {
225    /// Tool type/kind name as expected by the target provider (for example `web_search`).
226    #[serde(rename = "type")]
227    pub kind: String,
228    /// Additional provider-specific configuration for this hosted tool.
229    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
230    pub config: serde_json::Map<String, serde_json::Value>,
231}
232
233impl ProviderToolDefinition {
234    /// Creates a provider-hosted tool definition by type.
235    pub fn new(kind: impl Into<String>) -> Self {
236        Self {
237            kind: kind.into(),
238            config: serde_json::Map::new(),
239        }
240    }
241
242    /// Adds a provider-specific configuration key/value.
243    pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
244        self.config.insert(key.into(), value);
245        self
246    }
247}
248
249// ================================================================
250// Implementations
251// ================================================================
252/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
253pub trait Prompt: WasmCompatSend + WasmCompatSync {
254    /// Send a simple prompt to the underlying completion model.
255    ///
256    /// If the completion model's response is a message, then it is returned as a string.
257    ///
258    /// If the completion model's response is a tool call, then the tool is called and
259    /// the result is returned as a string.
260    ///
261    /// If the tool does not exist, or the tool call fails, then an error is returned.
262    fn prompt(
263        &self,
264        prompt: impl Into<Message> + WasmCompatSend,
265    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
266}
267
268/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
269pub trait Chat: WasmCompatSend + WasmCompatSync {
270    /// Send a prompt with optional chat history to the underlying completion model.
271    ///
272    /// If the completion model's response is a message, then it is returned as a string.
273    ///
274    /// If the completion model's response is a tool call, then the tool is called and the result
275    /// is returned as a string.
276    ///
277    /// If the tool does not exist, or the tool call fails, then an error is returned.
278    fn chat<I, T>(
279        &self,
280        prompt: impl Into<Message> + WasmCompatSend,
281        chat_history: I,
282    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend
283    where
284        I: IntoIterator<Item = T> + WasmCompatSend,
285        T: Into<Message>;
286}
287
288/// Trait defining a high-level typed prompt interface for structured output.
289///
290/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
291/// generating a JSON schema from the target type and deserializing the response.
292///
293/// # Example
294/// ```rust,ignore
295/// use rig::prelude::*;
296/// use schemars::JsonSchema;
297/// use serde::Deserialize;
298///
299/// #[derive(Debug, Deserialize, JsonSchema)]
300/// struct WeatherForecast {
301///     city: String,
302///     temperature_f: f64,
303///     conditions: String,
304/// }
305///
306/// let agent = client.agent("gpt-4o").build();
307/// let forecast: WeatherForecast = agent
308///     .prompt_typed("What's the weather in NYC?")
309///     .await?;
310/// ```
311pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
312    /// The type of the typed prompt request returned by `prompt_typed`.
313    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
314    where
315        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;
316
317    /// Send a prompt and receive a typed structured response.
318    ///
319    /// The JSON schema for `T` is automatically generated and sent to the provider.
320    /// Providers that support native structured outputs will constrain the model's
321    /// response to match this schema.
322    ///
323    /// # Type Parameters
324    /// * `T` - The target type to deserialize the response into. Must implement
325    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
326    ///   and `WasmCompatSend` (for async compatibility).
327    ///
328    /// # Example
329    /// ```rust,ignore
330    /// // Type can be inferred
331    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
332    ///
333    /// // Or specified explicitly with turbofish
334    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
335    /// ```
336    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
337    where
338        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
339}
340
341/// Trait defining a low-level LLM completion interface
342pub trait Completion<M: CompletionModel> {
343    /// Generates a completion request builder for the given `prompt` and `chat_history`.
344    /// This function is meant to be called by the user to further customize the
345    /// request at prompt time before sending it.
346    ///
347    /// ❗IMPORTANT: The type that implements this trait might have already
348    /// populated fields in the builder (the exact fields depend on the type).
349    /// For fields that have already been set by the model, calling the corresponding
350    /// method on the builder will overwrite the value set by the model.
351    ///
352    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
353    /// contain the `preamble` provided when creating the agent.
354    fn completion<I, T>(
355        &self,
356        prompt: impl Into<Message> + WasmCompatSend,
357        chat_history: I,
358    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
359    + WasmCompatSend
360    where
361        I: IntoIterator<Item = T> + WasmCompatSend,
362        T: Into<Message>;
363}
364
365/// General completion response struct that contains the high-level completion choice
366/// and the raw response. The completion choice contains one or more assistant content.
367#[derive(Debug)]
368pub struct CompletionResponse<T> {
369    /// The completion choice (represented by one or more assistant message content)
370    /// returned by the completion model provider
371    pub choice: OneOrMany<AssistantContent>,
372    /// Tokens used during prompting and responding
373    pub usage: Usage,
374    /// The raw response returned by the completion model provider
375    pub raw_response: T,
376    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
377    /// Used to pair reasoning input items with their output items in multi-turn.
378    pub message_id: Option<String>,
379}
380
381/// A trait for grabbing the token usage of a completion response.
382///
383/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
384pub trait GetTokenUsage {
385    fn token_usage(&self) -> Option<crate::completion::Usage>;
386}
387
388impl GetTokenUsage for () {
389    fn token_usage(&self) -> Option<crate::completion::Usage> {
390        None
391    }
392}
393
394impl<T> GetTokenUsage for Option<T>
395where
396    T: GetTokenUsage,
397{
398    fn token_usage(&self) -> Option<crate::completion::Usage> {
399        if let Some(usage) = self {
400            usage.token_usage()
401        } else {
402            None
403        }
404    }
405}
406
407/// Struct representing the token usage for a completion request.
408/// If tokens used are `0`, then the provider failed to supply token usage metrics.
409#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
410pub struct Usage {
411    /// The number of input ("prompt") tokens used in a given request.
412    pub input_tokens: u64,
413    /// The number of output ("completion") tokens used in a given request.
414    pub output_tokens: u64,
415    /// We store this separately as some providers may only report one number
416    pub total_tokens: u64,
417    /// The number of input tokens read from a provider-managed cache
418    pub cached_input_tokens: u64,
419    /// The number of input tokens written to a provider-managed cache
420    pub cache_creation_input_tokens: u64,
421}
422
423impl Usage {
424    /// Creates a new instance of `Usage`.
425    pub fn new() -> Self {
426        Self {
427            input_tokens: 0,
428            output_tokens: 0,
429            total_tokens: 0,
430            cached_input_tokens: 0,
431            cache_creation_input_tokens: 0,
432        }
433    }
434}
435
436impl Default for Usage {
437    fn default() -> Self {
438        Self::new()
439    }
440}
441
442impl Add for Usage {
443    type Output = Self;
444
445    fn add(self, other: Self) -> Self::Output {
446        Self {
447            input_tokens: self.input_tokens + other.input_tokens,
448            output_tokens: self.output_tokens + other.output_tokens,
449            total_tokens: self.total_tokens + other.total_tokens,
450            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
451            cache_creation_input_tokens: self.cache_creation_input_tokens
452                + other.cache_creation_input_tokens,
453        }
454    }
455}
456
457impl AddAssign for Usage {
458    fn add_assign(&mut self, other: Self) {
459        self.input_tokens += other.input_tokens;
460        self.output_tokens += other.output_tokens;
461        self.total_tokens += other.total_tokens;
462        self.cached_input_tokens += other.cached_input_tokens;
463        self.cache_creation_input_tokens += other.cache_creation_input_tokens;
464    }
465}
466
467/// Trait defining a completion model that can be used to generate completion responses.
468/// This trait is meant to be implemented by the user to define a custom completion model,
469/// either from a third party provider (e.g.: OpenAI) or a local model.
470pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
471    /// The raw response type returned by the underlying completion model.
472    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
473    /// The raw response type returned by the underlying completion model when streaming.
474    type StreamingResponse: Clone
475        + Unpin
476        + WasmCompatSend
477        + WasmCompatSync
478        + Serialize
479        + DeserializeOwned
480        + GetTokenUsage;
481
482    type Client;
483
484    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
485
486    /// Generates a completion response for the given completion request.
487    fn completion(
488        &self,
489        request: CompletionRequest,
490    ) -> impl std::future::Future<
491        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
492    > + WasmCompatSend;
493
494    fn stream(
495        &self,
496        request: CompletionRequest,
497    ) -> impl std::future::Future<
498        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
499    > + WasmCompatSend;
500
501    /// Generates a completion request builder for the given `prompt`.
502    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
503        CompletionRequestBuilder::new(self.clone(), prompt)
504    }
505}
506
507#[allow(deprecated)]
508#[deprecated(
509    since = "0.25.0",
510    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
511)]
512pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
513    fn completion(
514        &self,
515        request: CompletionRequest,
516    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;
517
518    fn stream(
519        &self,
520        request: CompletionRequest,
521    ) -> WasmBoxedFuture<
522        '_,
523        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
524    >;
525
526    fn completion_request(
527        &self,
528        prompt: Message,
529    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
530}
531
532#[allow(deprecated)]
533impl<T, R> CompletionModelDyn for T
534where
535    T: CompletionModel<StreamingResponse = R>,
536    R: Clone + Unpin + GetTokenUsage + 'static,
537{
538    fn completion(
539        &self,
540        request: CompletionRequest,
541    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
542        Box::pin(async move {
543            self.completion(request)
544                .await
545                .map(|resp| CompletionResponse {
546                    choice: resp.choice,
547                    usage: resp.usage,
548                    raw_response: (),
549                    message_id: resp.message_id,
550                })
551        })
552    }
553
554    fn stream(
555        &self,
556        request: CompletionRequest,
557    ) -> WasmBoxedFuture<
558        '_,
559        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
560    > {
561        Box::pin(async move {
562            let resp = self.stream(request).await?;
563            let inner = resp.inner;
564
565            let stream = streaming::StreamingResultDyn {
566                inner: Box::pin(inner),
567            };
568
569            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
570        })
571    }
572
573    /// Generates a completion request builder for the given `prompt`.
574    fn completion_request(
575        &self,
576        prompt: Message,
577    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
578        CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
579    }
580}
581
582/// Struct representing a general completion request that can be sent to a completion model provider.
583#[derive(Debug, Clone)]
584pub struct CompletionRequest {
585    /// Optional model override for this request.
586    pub model: Option<String>,
587    /// Legacy preamble field preserved for backwards compatibility.
588    ///
589    /// New code should prefer a leading [`Message::System`]
590    /// in `chat_history` as the canonical representation of system instructions.
591    pub preamble: Option<String>,
592    /// The chat history to be sent to the completion model provider.
593    /// The very last message will always be the prompt (hence why there is *always* one)
594    pub chat_history: OneOrMany<Message>,
595    /// The documents to be sent to the completion model provider
596    pub documents: Vec<Document>,
597    /// The tools to be sent to the completion model provider
598    pub tools: Vec<ToolDefinition>,
599    /// The temperature to be sent to the completion model provider
600    pub temperature: Option<f64>,
601    /// The max tokens to be sent to the completion model provider
602    pub max_tokens: Option<u64>,
603    /// Whether tools are required to be used by the model provider or not before providing a response.
604    pub tool_choice: Option<ToolChoice>,
605    /// Additional provider-specific parameters to be sent to the completion model provider
606    pub additional_params: Option<serde_json::Value>,
607    /// Optional JSON Schema for structured output. When set, providers that support
608    /// native structured outputs will constrain the model's response to match this schema.
609    pub output_schema: Option<schemars::Schema>,
610}
611
612impl CompletionRequest {
613    /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
614    /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
615    pub fn output_schema_name(&self) -> Option<String> {
616        self.output_schema.as_ref().map(|schema| {
617            schema
618                .as_object()
619                .and_then(|o| o.get("title"))
620                .and_then(|v| v.as_str())
621                .unwrap_or("response_schema")
622                .to_string()
623        })
624    }
625
626    /// Returns documents normalized into a message (if any).
627    /// Most providers do not accept documents directly as input, so it needs to convert into a
628    ///  `Message` so that it can be incorporated into `chat_history` as a
629    pub fn normalized_documents(&self) -> Option<Message> {
630        if self.documents.is_empty() {
631            return None;
632        }
633
634        // Most providers will convert documents into a text unless it can handle document messages.
635        // We use `UserContent::document` for those who handle it directly!
636        let messages = self
637            .documents
638            .iter()
639            .map(|doc| {
640                UserContent::document(
641                    doc.to_string(),
642                    // In the future, we can customize `Document` to pass these extra types through.
643                    // Most providers ditch these but they might want to use them.
644                    Some(DocumentMediaType::TXT),
645                )
646            })
647            .collect::<Vec<_>>();
648
649        Some(Message::User {
650            content: OneOrMany::many(messages).expect("There will be atleast one document"),
651        })
652    }
653
654    /// Adds a provider-hosted tool by storing it in `additional_params.tools`.
655    pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
656        self.additional_params =
657            merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
658        self
659    }
660
661    /// Adds provider-hosted tools by storing them in `additional_params.tools`.
662    pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
663        self.additional_params =
664            merge_provider_tools_into_additional_params(self.additional_params, tools);
665        self
666    }
667}
668
669fn merge_provider_tools_into_additional_params(
670    additional_params: Option<serde_json::Value>,
671    provider_tools: Vec<ProviderToolDefinition>,
672) -> Option<serde_json::Value> {
673    if provider_tools.is_empty() {
674        return additional_params;
675    }
676
677    let mut provider_tools_json = provider_tools
678        .into_iter()
679        .map(|ProviderToolDefinition { kind, mut config }| {
680            // Force the provider tool type from the strongly-typed field.
681            config.insert("type".to_string(), serde_json::Value::String(kind));
682            serde_json::Value::Object(config)
683        })
684        .collect::<Vec<_>>();
685
686    let mut params_map = match additional_params {
687        Some(serde_json::Value::Object(map)) => map,
688        Some(serde_json::Value::Bool(stream)) => {
689            let mut map = serde_json::Map::new();
690            map.insert("stream".to_string(), serde_json::Value::Bool(stream));
691            map
692        }
693        _ => serde_json::Map::new(),
694    };
695
696    let mut merged_tools = match params_map.remove("tools") {
697        Some(serde_json::Value::Array(existing)) => existing,
698        _ => Vec::new(),
699    };
700    merged_tools.append(&mut provider_tools_json);
701    params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
702    Some(serde_json::Value::Object(params_map))
703}
704
705/// Builder struct for constructing a completion request.
706///
707/// Example usage:
708/// ```rust
709/// use rig::{
710///     providers::openai::{Client, self},
711///     completion::CompletionRequestBuilder,
712/// };
713///
714/// let openai = Client::new("your-openai-api-key");
715/// let model = openai.completion_model(openai::GPT_4O).build();
716///
717/// // Create the completion request and execute it separately
718/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
719///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
720///     .temperature(0.5)
721///     .build();
722///
723/// let response = model.completion(request)
724///     .await
725///     .expect("Failed to get completion response");
726/// ```
727///
728/// Alternatively, you can execute the completion request directly from the builder:
729/// ```rust
730/// use rig::{
731///     providers::openai::{Client, self},
732///     completion::CompletionRequestBuilder,
733/// };
734///
735/// let openai = Client::new("your-openai-api-key");
736/// let model = openai.completion_model(openai::GPT_4O).build();
737///
738/// // Create the completion request and execute it directly
739/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
740///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
741///     .temperature(0.5)
742///     .send()
743///     .await
744///     .expect("Failed to get completion response");
745/// ```
746///
747/// Note: It is usually unnecessary to create a completion request builder directly.
748/// Instead, use the [CompletionModel::completion_request] method.
749pub struct CompletionRequestBuilder<M: CompletionModel> {
750    model: M,
751    prompt: Message,
752    request_model: Option<String>,
753    preamble: Option<String>,
754    chat_history: Vec<Message>,
755    documents: Vec<Document>,
756    tools: Vec<ToolDefinition>,
757    provider_tools: Vec<ProviderToolDefinition>,
758    temperature: Option<f64>,
759    max_tokens: Option<u64>,
760    tool_choice: Option<ToolChoice>,
761    additional_params: Option<serde_json::Value>,
762    output_schema: Option<schemars::Schema>,
763}
764
765impl<M: CompletionModel> CompletionRequestBuilder<M> {
766    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
767        Self {
768            model,
769            prompt: prompt.into(),
770            request_model: None,
771            preamble: None,
772            chat_history: Vec::new(),
773            documents: Vec::new(),
774            tools: Vec::new(),
775            provider_tools: Vec::new(),
776            temperature: None,
777            max_tokens: None,
778            tool_choice: None,
779            additional_params: None,
780            output_schema: None,
781        }
782    }
783
784    /// Sets the preamble for the completion request.
785    pub fn preamble(mut self, preamble: String) -> Self {
786        // Legacy public API: funnel preamble into canonical system messages at build-time.
787        self.preamble = Some(preamble);
788        self
789    }
790
791    /// Overrides the model used for this request.
792    pub fn model(mut self, model: impl Into<String>) -> Self {
793        self.request_model = Some(model.into());
794        self
795    }
796
797    /// Overrides the model used for this request.
798    pub fn model_opt(mut self, model: Option<String>) -> Self {
799        self.request_model = model;
800        self
801    }
802
803    pub fn without_preamble(mut self) -> Self {
804        self.preamble = None;
805        self
806    }
807
808    /// Adds a message to the chat history for the completion request.
809    pub fn message(mut self, message: Message) -> Self {
810        self.chat_history.push(message);
811
812        self
813    }
814
815    /// Adds a list of messages to the chat history for the completion request.
816    pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
817        self.chat_history.extend(messages);
818
819        self
820    }
821
822    /// Adds a document to the completion request.
823    pub fn document(mut self, document: Document) -> Self {
824        self.documents.push(document);
825        self
826    }
827
828    /// Adds a list of documents to the completion request.
829    pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
830        documents
831            .into_iter()
832            .fold(self, |builder, doc| builder.document(doc))
833    }
834
835    /// Adds a tool to the completion request.
836    pub fn tool(mut self, tool: ToolDefinition) -> Self {
837        self.tools.push(tool);
838        self
839    }
840
841    /// Adds a list of tools to the completion request.
842    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
843        tools
844            .into_iter()
845            .fold(self, |builder, tool| builder.tool(tool))
846    }
847
848    /// Adds a provider-hosted tool to the completion request.
849    pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
850        self.provider_tools.push(tool);
851        self
852    }
853
854    /// Adds provider-hosted tools to the completion request.
855    pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
856        tools
857            .into_iter()
858            .fold(self, |builder, tool| builder.provider_tool(tool))
859    }
860
861    /// Adds additional parameters to the completion request.
862    /// This can be used to set additional provider-specific parameters. For example,
863    /// Cohere's completion models accept a `connectors` parameter that can be used to
864    /// specify the data connectors used by Cohere when executing the completion
865    /// (see `examples/cohere_connectors.rs`).
866    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
867        match self.additional_params {
868            Some(params) => {
869                self.additional_params = Some(json_utils::merge(params, additional_params));
870            }
871            None => {
872                self.additional_params = Some(additional_params);
873            }
874        }
875        self
876    }
877
878    /// Sets the additional parameters for the completion request.
879    /// This can be used to set additional provider-specific parameters. For example,
880    /// Cohere's completion models accept a `connectors` parameter that can be used to
881    /// specify the data connectors used by Cohere when executing the completion
882    /// (see `examples/cohere_connectors.rs`).
883    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
884        self.additional_params = additional_params;
885        self
886    }
887
888    /// Sets the temperature for the completion request.
889    pub fn temperature(mut self, temperature: f64) -> Self {
890        self.temperature = Some(temperature);
891        self
892    }
893
894    /// Sets the temperature for the completion request.
895    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
896        self.temperature = temperature;
897        self
898    }
899
900    /// Sets the max tokens for the completion request.
901    /// Note: This is required if using Anthropic
902    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
903        self.max_tokens = Some(max_tokens);
904        self
905    }
906
907    /// Sets the max tokens for the completion request.
908    /// Note: This is required if using Anthropic
909    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
910        self.max_tokens = max_tokens;
911        self
912    }
913
914    /// Sets the thing.
915    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
916        self.tool_choice = Some(tool_choice);
917        self
918    }
919
920    /// Sets the output schema for structured output. When set, providers that support
921    /// native structured outputs will constrain the model's response to match this schema.
922    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
923    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
924    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
925    /// to still be able to leverage structured outputs.
926    pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
927        self.output_schema = Some(schema);
928        self
929    }
930
931    /// Sets the output schema for structured output from an optional value.
932    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
933    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
934    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
935    /// to still be able to leverage structured outputs.
936    pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
937        self.output_schema = schema;
938        self
939    }
940
941    /// Builds the completion request.
942    pub fn build(self) -> CompletionRequest {
943        // Build the final message list, prepending preamble if present
944        let mut chat_history = self.chat_history;
945        if let Some(preamble) = self.preamble {
946            chat_history.insert(0, Message::system(preamble));
947        }
948        chat_history.push(self.prompt);
949
950        let chat_history =
951            OneOrMany::many(chat_history).expect("There will always be at least the prompt");
952        let additional_params = merge_provider_tools_into_additional_params(
953            self.additional_params,
954            self.provider_tools,
955        );
956
957        CompletionRequest {
958            model: self.request_model,
959            preamble: None,
960            chat_history,
961            documents: self.documents,
962            tools: self.tools,
963            temperature: self.temperature,
964            max_tokens: self.max_tokens,
965            tool_choice: self.tool_choice,
966            additional_params,
967            output_schema: self.output_schema,
968        }
969    }
970
971    /// Sends the completion request to the completion model provider and returns the completion response.
972    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
973        let model = self.model.clone();
974        model.completion(self.build()).await
975    }
976
977    /// Stream the completion request
978    pub async fn stream<'a>(
979        self,
980    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
981    where
982        <M as CompletionModel>::StreamingResponse: 'a,
983        Self: 'a,
984    {
985        let model = self.model.clone();
986        model.stream(self.build()).await
987    }
988}
989
990#[cfg(test)]
991mod tests {
992
993    use super::*;
994    use crate::streaming::StreamingCompletionResponse;
995    use serde::{Deserialize, Serialize};
996
    /// Stateless stand-in completion model used to construct builders in tests.
    #[derive(Clone)]
    struct DummyModel;
999
    /// Empty payload type used as `DummyModel`'s streaming response.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct DummyStreamingResponse;
1002
    // The dummy streaming response carries no usage information.
    impl GetTokenUsage for DummyStreamingResponse {
        /// Always returns `None`.
        fn token_usage(&self) -> Option<Usage> {
            None
        }
    }
1008
1009    impl CompletionModel for DummyModel {
1010        type Response = serde_json::Value;
1011        type StreamingResponse = DummyStreamingResponse;
1012        type Client = ();
1013
1014        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
1015            Self
1016        }
1017
1018        async fn completion(
1019            &self,
1020            _request: CompletionRequest,
1021        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
1022            Err(CompletionError::ProviderError(
1023                "dummy completion model".to_string(),
1024            ))
1025        }
1026
1027        async fn stream(
1028            &self,
1029            _request: CompletionRequest,
1030        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
1031            Err(CompletionError::ProviderError(
1032                "dummy completion model".to_string(),
1033            ))
1034        }
1035    }
1036
1037    #[test]
1038    fn test_document_display_without_metadata() {
1039        let doc = Document {
1040            id: "123".to_string(),
1041            text: "This is a test document.".to_string(),
1042            additional_props: HashMap::new(),
1043        };
1044
1045        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
1046        assert_eq!(format!("{doc}"), expected);
1047    }
1048
1049    #[test]
1050    fn test_document_display_with_metadata() {
1051        let mut additional_props = HashMap::new();
1052        additional_props.insert("author".to_string(), "John Doe".to_string());
1053        additional_props.insert("length".to_string(), "42".to_string());
1054
1055        let doc = Document {
1056            id: "123".to_string(),
1057            text: "This is a test document.".to_string(),
1058            additional_props,
1059        };
1060
1061        let expected = concat!(
1062            "<file id: 123>\n",
1063            "<metadata author: \"John Doe\" length: \"42\" />\n",
1064            "This is a test document.\n",
1065            "</file>\n"
1066        );
1067        assert_eq!(format!("{doc}"), expected);
1068    }
1069
1070    #[test]
1071    fn test_normalize_documents_with_documents() {
1072        let doc1 = Document {
1073            id: "doc1".to_string(),
1074            text: "Document 1 text.".to_string(),
1075            additional_props: HashMap::new(),
1076        };
1077
1078        let doc2 = Document {
1079            id: "doc2".to_string(),
1080            text: "Document 2 text.".to_string(),
1081            additional_props: HashMap::new(),
1082        };
1083
1084        let request = CompletionRequest {
1085            model: None,
1086            preamble: None,
1087            chat_history: OneOrMany::one("What is the capital of France?".into()),
1088            documents: vec![doc1, doc2],
1089            tools: Vec::new(),
1090            temperature: None,
1091            max_tokens: None,
1092            tool_choice: None,
1093            additional_params: None,
1094            output_schema: None,
1095        };
1096
1097        let expected = Message::User {
1098            content: OneOrMany::many(vec![
1099                UserContent::document(
1100                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
1101                    Some(DocumentMediaType::TXT),
1102                ),
1103                UserContent::document(
1104                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
1105                    Some(DocumentMediaType::TXT),
1106                ),
1107            ])
1108            .expect("There will be at least one document"),
1109        };
1110
1111        assert_eq!(request.normalized_documents(), Some(expected));
1112    }
1113
1114    #[test]
1115    fn test_normalize_documents_without_documents() {
1116        let request = CompletionRequest {
1117            model: None,
1118            preamble: None,
1119            chat_history: OneOrMany::one("What is the capital of France?".into()),
1120            documents: Vec::new(),
1121            tools: Vec::new(),
1122            temperature: None,
1123            max_tokens: None,
1124            tool_choice: None,
1125            additional_params: None,
1126            output_schema: None,
1127        };
1128
1129        assert_eq!(request.normalized_documents(), None);
1130    }
1131
1132    #[test]
1133    fn preamble_builder_funnels_to_system_message() {
1134        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
1135            .preamble("System prompt".to_string())
1136            .message(Message::user("History"))
1137            .build();
1138
1139        assert_eq!(request.preamble, None);
1140
1141        let history = request.chat_history.into_iter().collect::<Vec<_>>();
1142        assert_eq!(history.len(), 3);
1143        assert!(matches!(
1144            &history[0],
1145            Message::System { content } if content == "System prompt"
1146        ));
1147        assert!(matches!(&history[1], Message::User { .. }));
1148        assert!(matches!(&history[2], Message::User { .. }));
1149    }
1150
1151    #[test]
1152    fn without_preamble_removes_legacy_preamble_injection() {
1153        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
1154            .preamble("System prompt".to_string())
1155            .without_preamble()
1156            .build();
1157
1158        assert_eq!(request.preamble, None);
1159        let history = request.chat_history.into_iter().collect::<Vec<_>>();
1160        assert_eq!(history.len(), 1);
1161        assert!(matches!(&history[0], Message::User { .. }));
1162    }
1163}