Skip to main content

rig/completion/
request.rs

1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
//! ```rust,ignore
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the completion response: `choice` holds one or more pieces of assistant content
//! for content in response.choice.iter() {
//!     match content {
//!         AssistantContent::Text(text) => {
//!             // Handle the completion response as a message
//!             println!("Received message: {}", text.text);
//!         }
//!         AssistantContent::ToolCall(tool_call) => {
//!             // Handle the completion response as a tool call
//!             println!("Received tool call: {:?}", tool_call);
//!         }
//!         _ => {}
//!     }
//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76    json_utils,
77    message::{Message, UserContent},
78    tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
/// Errors that can occur while building, sending, or interpreting a completion request.
///
/// The `#[from]` attributes provide `From` conversions, so `?` can lift the wrapped
/// error types into a `CompletionError`.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request.
    ///
    /// On non-wasm targets the boxed error must also be `Send + Sync`.
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request.
    ///
    /// On wasm targets the boxed error carries no `Send + Sync` bound.
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
120
/// Prompt errors
///
/// Errors surfaced by the high-level prompting interfaces ([`Prompt`], [`Chat`]).
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool.
    ///
    /// Note: this variant's `Display` prefix is `ToolCallError`, not `ToolError`.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        /// The turn limit that was reached.
        max_turns: usize,
        /// Chat history carried with the error (presumably the history accumulated up to
        /// the limit — confirm at the construction site).
        chat_history: Box<Vec<Message>>,
        /// Prompt associated with the failed run (presumably the original user prompt —
        /// confirm at the construction site).
        prompt: Box<Message>,
    },

    /// A prompting loop was cancelled.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// Chat history captured when the loop was cancelled.
        chat_history: Box<Vec<Message>>,
        /// Human-readable cancellation reason; shown in the `Display` output.
        reason: String,
    },
}
153
154impl PromptError {
155    pub(crate) fn prompt_cancelled(chat_history: Vec<Message>, reason: impl Into<String>) -> Self {
156        Self::PromptCancelled {
157            chat_history: Box::new(chat_history),
158            reason: reason.into(),
159        }
160    }
161}
162
/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
///
/// `From` conversions exist for [`PromptError`] and `serde_json::Error` (via `#[from]`).
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// An error occurred during the prompt execution.
    #[error("PromptError: {0}")]
    PromptError(#[from] PromptError),

    /// Failed to deserialize the model's response into the target type.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned an empty response.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
178
/// A document attached to a completion request (see [`CompletionRequest::documents`]).
///
/// Its `Display` impl renders it as a `<file id: ...>` block, preceded inside the block by a
/// `<metadata ... />` line when `additional_props` is non-empty.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier rendered in the `<file id: ...>` header.
    pub id: String,
    /// The document body.
    pub text: String,
    /// Extra string properties. Because of `#[serde(flatten)]` they are (de)serialized as
    /// top-level keys next to `id` and `text`.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
186
187impl std::fmt::Display for Document {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        write!(
190            f,
191            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
192            self.id,
193            if self.additional_props.is_empty() {
194                self.text.clone()
195            } else {
196                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
197                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
198                let metadata = sorted_props
199                    .iter()
200                    .map(|(k, v)| format!("{k}: {v:?}"))
201                    .collect::<Vec<_>>()
202                    .join(" ");
203                format!("<metadata {} />\n{}", metadata, self.text)
204            }
205        )
206    }
207}
208
/// Definition of a client-side tool that is sent to the completion model provider
/// (see [`CompletionRequest::tools`]).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter specification as raw JSON (presumably a JSON Schema object — confirm
    /// against the providers that consume it).
    pub parameters: serde_json::Value,
}
215
/// Provider-native tool definition.
///
/// Stored under `additional_params.tools` and forwarded by providers that support
/// provider-managed tools.
///
/// Serializes as one flat JSON object: `{"type": <kind>, ...config}`. Avoid storing a
/// `"type"` key in `config`: because `config` is flattened, it would collide with `kind`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type/kind name as expected by the target provider (for example `web_search`).
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional provider-specific configuration for this hosted tool.
    /// Omitted entirely from the serialized form when empty.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
229
230impl ProviderToolDefinition {
231    /// Creates a provider-hosted tool definition by type.
232    pub fn new(kind: impl Into<String>) -> Self {
233        Self {
234            kind: kind.into(),
235            config: serde_json::Map::new(),
236        }
237    }
238
239    /// Adds a provider-specific configuration key/value.
240    pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
241        self.config.insert(key.into(), value);
242        self
243    }
244}
245
246// ================================================================
247// Implementations
248// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value only has to implement [`IntoFuture`](std::future::IntoFuture)
    /// (not `Future`), so drive it with `.await`.
    ///
    /// # Errors
    /// Resolves to a [`PromptError`] on failure.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
264
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// `chat_history` is taken by value; pass an empty `Vec` for a conversation with no
    /// prior messages. The returned value implements
    /// [`IntoFuture`](std::future::IntoFuture) and is driven with `.await`.
    ///
    /// # Errors
    /// Resolves to a [`PromptError`] on failure.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
281
/// Trait defining a high-level typed prompt interface for structured output.
///
/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
/// generating a JSON schema from the target type and deserializing the response.
///
/// # Example
/// ```rust,ignore
/// use rig::prelude::*;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct WeatherForecast {
///     city: String,
///     temperature_f: f64,
///     conditions: String,
/// }
///
/// let agent = client.agent("gpt-4o").build();
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .await?;
/// ```
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The type of the typed prompt request returned by `prompt_typed`.
    ///
    /// `'a` is the lifetime of the borrow of `self`, and `T` is the target output type.
    /// Awaiting the request (it implements [`IntoFuture`](std::future::IntoFuture)) yields
    /// either the deserialized `T` or a [`StructuredOutputError`].
    type TypedRequest<'a, T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        Self: 'a,
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Type Parameters
    /// * `T` - The target type to deserialize the response into. Must implement
    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
    ///   and `WasmCompatSend` (for async compatibility).
    ///
    /// # Errors
    /// Resolves to [`StructuredOutputError::PromptError`] if the prompt itself fails,
    /// [`StructuredOutputError::DeserializationError`] if the response does not deserialize
    /// into `T`, or [`StructuredOutputError::EmptyResponse`] if the model returns no content.
    ///
    /// # Example
    /// ```rust,ignore
    /// // Type can be inferred
    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
    ///
    /// // Or specified explicitly with turbofish
    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> Self::TypedRequest<'_, T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
338
/// Trait defining a low-level LLM completion interface
///
/// `M` is the [`CompletionModel`] the produced request builder targets.
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] if the builder cannot be prepared.
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}
359
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// `T` is the provider-specific raw response type (see [`CompletionModel::Response`]);
/// the type-erased [`CompletionModelDyn::completion`] path uses `()`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Used to pair reasoning input items with their output items in multi-turn.
    pub message_id: Option<String>,
}
375
/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
pub trait GetTokenUsage {
    /// Returns the token usage carried by this value, or `None` if it has none to report.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
382
383impl GetTokenUsage for () {
384    fn token_usage(&self) -> Option<crate::completion::Usage> {
385        None
386    }
387}
388
389impl<T> GetTokenUsage for Option<T>
390where
391    T: GetTokenUsage,
392{
393    fn token_usage(&self) -> Option<crate::completion::Usage> {
394        if let Some(usage) = self {
395            usage.token_usage()
396        } else {
397            None
398        }
399    }
400}
401
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
///
/// Usage values can be combined with `+` / `+=`, which sum every counter field by field.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// Total tokens as reported by the provider.
    /// We store this separately as some providers may only report one number
    pub total_tokens: u64,
    /// The number of cached input tokens (from prompt caching). 0 if not reported by provider.
    pub cached_input_tokens: u64,
}
415
416impl Usage {
417    /// Creates a new instance of `Usage`.
418    pub fn new() -> Self {
419        Self {
420            input_tokens: 0,
421            output_tokens: 0,
422            total_tokens: 0,
423            cached_input_tokens: 0,
424        }
425    }
426}
427
428impl Default for Usage {
429    fn default() -> Self {
430        Self::new()
431    }
432}
433
434impl Add for Usage {
435    type Output = Self;
436
437    fn add(self, other: Self) -> Self::Output {
438        Self {
439            input_tokens: self.input_tokens + other.input_tokens,
440            output_tokens: self.output_tokens + other.output_tokens,
441            total_tokens: self.total_tokens + other.total_tokens,
442            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
443        }
444    }
445}
446
447impl AddAssign for Usage {
448    fn add_assign(&mut self, other: Self) {
449        self.input_tokens += other.input_tokens;
450        self.output_tokens += other.output_tokens;
451        self.total_tokens += other.total_tokens;
452        self.cached_input_tokens += other.cached_input_tokens;
453    }
454}
455
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// The client type that [`CompletionModel::make`] constructs this model from.
    type Client;

    /// Creates a completion model from `client` and a model identifier
    /// (presumably the provider's model name — confirm per provider).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] if the request cannot be completed.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Sends the given completion request and returns a streaming response whose raw
    /// items are of type [`CompletionModel::StreamingResponse`].
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] if the stream cannot be started.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    ///
    /// The builder owns a clone of `self`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
495
/// Type-erased counterpart of [`CompletionModel`]: futures are boxed and raw provider
/// responses are erased (`()` for completions, [`FinalCompletionResponse`] for streams).
///
/// Every [`CompletionModel`] with a `'static` streaming response type gets this trait via
/// a blanket impl.
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Sends `request` and returns the response with its raw provider payload replaced by `()`.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Sends `request` and returns a stream whose raw responses are erased to
    /// [`FinalCompletionResponse`].
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    /// Creates a request builder backed by a [`CompletionModelHandle`] for this model.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
520
521#[allow(deprecated)]
522impl<T, R> CompletionModelDyn for T
523where
524    T: CompletionModel<StreamingResponse = R>,
525    R: Clone + Unpin + GetTokenUsage + 'static,
526{
527    fn completion(
528        &self,
529        request: CompletionRequest,
530    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
531        Box::pin(async move {
532            self.completion(request)
533                .await
534                .map(|resp| CompletionResponse {
535                    choice: resp.choice,
536                    usage: resp.usage,
537                    raw_response: (),
538                    message_id: resp.message_id,
539                })
540        })
541    }
542
543    fn stream(
544        &self,
545        request: CompletionRequest,
546    ) -> WasmBoxedFuture<
547        '_,
548        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
549    > {
550        Box::pin(async move {
551            let resp = self.stream(request).await?;
552            let inner = resp.inner;
553
554            let stream = streaming::StreamingResultDyn {
555                inner: Box::pin(inner),
556            };
557
558            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
559        })
560    }
561
562    /// Generates a completion request builder for the given `prompt`.
563    fn completion_request(
564        &self,
565        prompt: Message,
566    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
567        CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
568    }
569}
570
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Optional model override for this request.
    pub model: Option<String>,
    /// Legacy preamble field preserved for backwards compatibility.
    ///
    /// New code should prefer a leading [`Message::System`]
    /// in `chat_history` as the canonical representation of system instructions.
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    /// (see [`CompletionRequest::normalized_documents`] for the message form).
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider.
    ///
    /// [`CompletionRequest::with_provider_tool`] / [`CompletionRequest::with_provider_tools`]
    /// store provider-hosted tools under the `tools` key of this value.
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    /// See also [`CompletionRequest::output_schema_name`].
    pub output_schema: Option<schemars::Schema>,
}
600
601impl CompletionRequest {
602    /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
603    /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
604    pub fn output_schema_name(&self) -> Option<String> {
605        self.output_schema.as_ref().map(|schema| {
606            schema
607                .as_object()
608                .and_then(|o| o.get("title"))
609                .and_then(|v| v.as_str())
610                .unwrap_or("response_schema")
611                .to_string()
612        })
613    }
614
615    /// Returns documents normalized into a message (if any).
616    /// Most providers do not accept documents directly as input, so it needs to convert into a
617    ///  `Message` so that it can be incorporated into `chat_history` as a
618    pub fn normalized_documents(&self) -> Option<Message> {
619        if self.documents.is_empty() {
620            return None;
621        }
622
623        // Most providers will convert documents into a text unless it can handle document messages.
624        // We use `UserContent::document` for those who handle it directly!
625        let messages = self
626            .documents
627            .iter()
628            .map(|doc| {
629                UserContent::document(
630                    doc.to_string(),
631                    // In the future, we can customize `Document` to pass these extra types through.
632                    // Most providers ditch these but they might want to use them.
633                    Some(DocumentMediaType::TXT),
634                )
635            })
636            .collect::<Vec<_>>();
637
638        Some(Message::User {
639            content: OneOrMany::many(messages).expect("There will be atleast one document"),
640        })
641    }
642
643    /// Adds a provider-hosted tool by storing it in `additional_params.tools`.
644    pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
645        self.additional_params =
646            merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
647        self
648    }
649
650    /// Adds provider-hosted tools by storing them in `additional_params.tools`.
651    pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
652        self.additional_params =
653            merge_provider_tools_into_additional_params(self.additional_params, tools);
654        self
655    }
656}
657
658fn merge_provider_tools_into_additional_params(
659    additional_params: Option<serde_json::Value>,
660    provider_tools: Vec<ProviderToolDefinition>,
661) -> Option<serde_json::Value> {
662    if provider_tools.is_empty() {
663        return additional_params;
664    }
665
666    let mut provider_tools_json = provider_tools
667        .into_iter()
668        .map(|ProviderToolDefinition { kind, mut config }| {
669            // Force the provider tool type from the strongly-typed field.
670            config.insert("type".to_string(), serde_json::Value::String(kind));
671            serde_json::Value::Object(config)
672        })
673        .collect::<Vec<_>>();
674
675    let mut params_map = match additional_params {
676        Some(serde_json::Value::Object(map)) => map,
677        Some(serde_json::Value::Bool(stream)) => {
678            let mut map = serde_json::Map::new();
679            map.insert("stream".to_string(), serde_json::Value::Bool(stream));
680            map
681        }
682        _ => serde_json::Map::new(),
683    };
684
685    let mut merged_tools = match params_map.remove("tools") {
686        Some(serde_json::Value::Array(existing)) => existing,
687        _ => Vec::new(),
688    };
689    merged_tools.append(&mut provider_tools_json);
690    params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
691    Some(serde_json::Value::Object(params_map))
692}
693
694/// Builder struct for constructing a completion request.
695///
696/// Example usage:
697/// ```rust
698/// use rig::{
699///     providers::openai::{Client, self},
700///     completion::CompletionRequestBuilder,
701/// };
702///
703/// let openai = Client::new("your-openai-api-key");
704/// let model = openai.completion_model(openai::GPT_4O).build();
705///
706/// // Create the completion request and execute it separately
707/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
708///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
709///     .temperature(0.5)
710///     .build();
711///
712/// let response = model.completion(request)
713///     .await
714///     .expect("Failed to get completion response");
715/// ```
716///
717/// Alternatively, you can execute the completion request directly from the builder:
718/// ```rust
719/// use rig::{
720///     providers::openai::{Client, self},
721///     completion::CompletionRequestBuilder,
722/// };
723///
724/// let openai = Client::new("your-openai-api-key");
725/// let model = openai.completion_model(openai::GPT_4O).build();
726///
727/// // Create the completion request and execute it directly
728/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
729///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
730///     .temperature(0.5)
731///     .send()
732///     .await
733///     .expect("Failed to get completion response");
734/// ```
735///
736/// Note: It is usually unnecessary to create a completion request builder directly.
737/// Instead, use the [CompletionModel::completion_request] method.
738pub struct CompletionRequestBuilder<M: CompletionModel> {
739    model: M,
740    prompt: Message,
741    request_model: Option<String>,
742    preamble: Option<String>,
743    chat_history: Vec<Message>,
744    documents: Vec<Document>,
745    tools: Vec<ToolDefinition>,
746    provider_tools: Vec<ProviderToolDefinition>,
747    temperature: Option<f64>,
748    max_tokens: Option<u64>,
749    tool_choice: Option<ToolChoice>,
750    additional_params: Option<serde_json::Value>,
751    output_schema: Option<schemars::Schema>,
752}
753
754impl<M: CompletionModel> CompletionRequestBuilder<M> {
755    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
756        Self {
757            model,
758            prompt: prompt.into(),
759            request_model: None,
760            preamble: None,
761            chat_history: Vec::new(),
762            documents: Vec::new(),
763            tools: Vec::new(),
764            provider_tools: Vec::new(),
765            temperature: None,
766            max_tokens: None,
767            tool_choice: None,
768            additional_params: None,
769            output_schema: None,
770        }
771    }
772
773    /// Sets the preamble for the completion request.
774    pub fn preamble(mut self, preamble: String) -> Self {
775        // Legacy public API: funnel preamble into canonical system messages at build-time.
776        self.preamble = Some(preamble);
777        self
778    }
779
780    /// Overrides the model used for this request.
781    pub fn model(mut self, model: impl Into<String>) -> Self {
782        self.request_model = Some(model.into());
783        self
784    }
785
786    /// Overrides the model used for this request.
787    pub fn model_opt(mut self, model: Option<String>) -> Self {
788        self.request_model = model;
789        self
790    }
791
792    pub fn without_preamble(mut self) -> Self {
793        self.preamble = None;
794        self
795    }
796
797    /// Adds a message to the chat history for the completion request.
798    pub fn message(mut self, message: Message) -> Self {
799        self.chat_history.push(message);
800        self
801    }
802
803    /// Adds a list of messages to the chat history for the completion request.
804    pub fn messages(self, messages: Vec<Message>) -> Self {
805        messages
806            .into_iter()
807            .fold(self, |builder, msg| builder.message(msg))
808    }
809
810    /// Adds a document to the completion request.
811    pub fn document(mut self, document: Document) -> Self {
812        self.documents.push(document);
813        self
814    }
815
816    /// Adds a list of documents to the completion request.
817    pub fn documents(self, documents: Vec<Document>) -> Self {
818        documents
819            .into_iter()
820            .fold(self, |builder, doc| builder.document(doc))
821    }
822
823    /// Adds a tool to the completion request.
824    pub fn tool(mut self, tool: ToolDefinition) -> Self {
825        self.tools.push(tool);
826        self
827    }
828
829    /// Adds a list of tools to the completion request.
830    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
831        tools
832            .into_iter()
833            .fold(self, |builder, tool| builder.tool(tool))
834    }
835
836    /// Adds a provider-hosted tool to the completion request.
837    pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
838        self.provider_tools.push(tool);
839        self
840    }
841
842    /// Adds provider-hosted tools to the completion request.
843    pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
844        tools
845            .into_iter()
846            .fold(self, |builder, tool| builder.provider_tool(tool))
847    }
848
849    /// Adds additional parameters to the completion request.
850    /// This can be used to set additional provider-specific parameters. For example,
851    /// Cohere's completion models accept a `connectors` parameter that can be used to
852    /// specify the data connectors used by Cohere when executing the completion
853    /// (see `examples/cohere_connectors.rs`).
854    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
855        match self.additional_params {
856            Some(params) => {
857                self.additional_params = Some(json_utils::merge(params, additional_params));
858            }
859            None => {
860                self.additional_params = Some(additional_params);
861            }
862        }
863        self
864    }
865
866    /// Sets the additional parameters for the completion request.
867    /// This can be used to set additional provider-specific parameters. For example,
868    /// Cohere's completion models accept a `connectors` parameter that can be used to
869    /// specify the data connectors used by Cohere when executing the completion
870    /// (see `examples/cohere_connectors.rs`).
871    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
872        self.additional_params = additional_params;
873        self
874    }
875
876    /// Sets the temperature for the completion request.
877    pub fn temperature(mut self, temperature: f64) -> Self {
878        self.temperature = Some(temperature);
879        self
880    }
881
882    /// Sets the temperature for the completion request.
883    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
884        self.temperature = temperature;
885        self
886    }
887
888    /// Sets the max tokens for the completion request.
889    /// Note: This is required if using Anthropic
890    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
891        self.max_tokens = Some(max_tokens);
892        self
893    }
894
895    /// Sets the max tokens for the completion request.
896    /// Note: This is required if using Anthropic
897    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
898        self.max_tokens = max_tokens;
899        self
900    }
901
902    /// Sets the thing.
903    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
904        self.tool_choice = Some(tool_choice);
905        self
906    }
907
908    /// Sets the output schema for structured output. When set, providers that support
909    /// native structured outputs will constrain the model's response to match this schema.
910    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
911    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
912    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
913    /// to still be able to leverage structured outputs.
914    pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
915        self.output_schema = Some(schema);
916        self
917    }
918
919    /// Sets the output schema for structured output from an optional value.
920    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
921    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
922    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
923    /// to still be able to leverage structured outputs.
924    pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
925        self.output_schema = schema;
926        self
927    }
928
929    /// Builds the completion request.
930    pub fn build(self) -> CompletionRequest {
931        let mut chat_history = self.chat_history;
932        if let Some(preamble) = self.preamble {
933            chat_history.insert(0, Message::system(preamble));
934        }
935        let chat_history = OneOrMany::many([chat_history, vec![self.prompt]].concat())
936            .expect("There will always be atleast the prompt");
937        let additional_params = merge_provider_tools_into_additional_params(
938            self.additional_params,
939            self.provider_tools,
940        );
941
942        CompletionRequest {
943            model: self.request_model,
944            preamble: None,
945            chat_history,
946            documents: self.documents,
947            tools: self.tools,
948            temperature: self.temperature,
949            max_tokens: self.max_tokens,
950            tool_choice: self.tool_choice,
951            additional_params,
952            output_schema: self.output_schema,
953        }
954    }
955
956    /// Sends the completion request to the completion model provider and returns the completion response.
957    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
958        let model = self.model.clone();
959        model.completion(self.build()).await
960    }
961
962    /// Stream the completion request
963    pub async fn stream<'a>(
964        self,
965    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
966    where
967        <M as CompletionModel>::StreamingResponse: 'a,
968        Self: 'a,
969    {
970        let model = self.model.clone();
971        model.stream(self.build()).await
972    }
973}
974
#[cfg(test)]
mod tests {

    use super::*;
    use crate::streaming::StreamingCompletionResponse;
    use serde::{Deserialize, Serialize};

    /// Minimal model used only to exercise the request builder; every call fails.
    #[derive(Clone)]
    struct DummyModel;

    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct DummyStreamingResponse;

    impl GetTokenUsage for DummyStreamingResponse {
        fn token_usage(&self) -> Option<Usage> {
            None
        }
    }

    impl CompletionModel for DummyModel {
        type Response = serde_json::Value;
        type StreamingResponse = DummyStreamingResponse;
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            Self
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }

    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .message(Message::user("History"))
            .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        // Exact equality pins the order: history first, prompt last.
        assert_eq!(history[1], Message::user("History"));
        assert_eq!(history[2], Message::user("Prompt"));
    }

    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .without_preamble()
            .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0], Message::user("Prompt"));
    }

    #[test]
    fn build_places_history_before_prompt_in_insertion_order() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .messages(vec![Message::user("First"), Message::user("Second")])
            .message(Message::user("Third"))
            .build();

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(
            history,
            vec![
                Message::user("First"),
                Message::user("Second"),
                Message::user("Third"),
                Message::user("Prompt"),
            ]
        );
    }

    #[test]
    fn model_override_is_forwarded_and_clearable() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .model("override-model")
            .build();
        assert_eq!(request.model.as_deref(), Some("override-model"));

        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .model("override-model")
            .model_opt(None)
            .build();
        assert_eq!(request.model, None);
    }
}