Skip to main content

rig_core/completion/
request.rs

1//! Completion request, response, and provider trait definitions.
2//!
3//! Most applications use [`Prompt`] or [`Chat`] through
4//! [`Agent`](crate::agent::Agent). Provider integrations implement
5//! [`CompletionModel`] and translate [`CompletionRequest`] into their native HTTP
6//! request format.
7//!
8//! # Low-level request example
9//!
10//! ```no_run
11//! use rig_core::{
12//!     client::{CompletionClient, ProviderClient},
13//!     completion::{AssistantContent, CompletionModel},
14//!     providers::openai,
15//! };
16//!
17//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
18//! let client = openai::Client::from_env()?;
19//! let model = client.completion_model(openai::GPT_5_2);
20//!
21//! let request = model
22//!     .completion_request("Who are you?")
23//!     .preamble("You are a concise assistant.".to_string())
24//!     .temperature(0.5)
25//!     .build();
26//!
27//! let response = model.completion(request).await?;
28//! for item in response.choice {
29//!     if let AssistantContent::Text(text) = item {
30//!         println!("{}", text.text);
31//!     }
32//! }
33//! # Ok(())
34//! # }
35//! ```
36
37use super::message::{AssistantContent, DocumentMediaType};
38use crate::message::ToolChoice;
39use crate::streaming::StreamingCompletionResponse;
40use crate::tool::server::ToolServerError;
41use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
42use crate::{OneOrMany, http_client};
43use crate::{
44    json_utils,
45    message::{Message, UserContent},
46    tool::ToolSetError,
47};
48use serde::de::DeserializeOwned;
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use std::ops::{Add, AddAssign};
52use thiserror::Error;
53
// Errors

/// Errors produced while building, sending, or decoding a completion request.
///
/// The HTTP, JSON, URL and boxed-request variants are marked `#[from]`, so the
/// corresponding source errors convert into a `CompletionError` with `?`.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    // `RequestError` is declared twice; the two definitions differ only in the
    // bounds on the boxed error. Outside wasm it must be `Send + Sync`; on wasm
    // those bounds are dropped (presumably because wasm futures need not be
    // `Send` — see `WasmCompatSend`).
    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request.
    ///
    /// Also used to surface conversation-memory failures (see the
    /// `From<crate::memory::MemoryError>` impl for `PromptError`).
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request (wasm variant without
    /// `Send + Sync` bounds).
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
87
/// Prompt errors
///
/// Returned by the high-level [`Prompt`] and [`Chat`] interfaces. Several
/// variants also carry chat history alongside the error.
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    // Note: rendered with the `ToolCallError:` prefix, not the variant name.
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] Box<ToolServerError>),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxTurnsError: reached max turns limit: {max_turns}")]
    MaxTurnsError {
        /// The turn limit that was reached.
        max_turns: usize,
        /// Chat history carried with the error.
        chat_history: Box<Vec<Message>>,
        /// Prompt message carried with the error.
        prompt: Box<Message>,
    },

    /// A prompting loop was cancelled.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// Chat history collected when the cancellation was raised.
        chat_history: Vec<Message>,
        /// Human-readable cancellation reason (shown in the error message).
        reason: String,
    },

    /// The model emitted a structured tool call for a tool Rig did not allow
    /// for the current turn.
    #[error(
        "UnknownToolCall: model attempted to call unknown or disallowed tool `{tool_name}`. Available tools: {available_tools:?}. Allowed tools for this turn: {allowed_tools:?}"
    )]
    UnknownToolCall {
        /// Name of the tool the model tried to call.
        tool_name: String,
        /// Names of all available tools (listed in the error message).
        available_tools: Vec<String>,
        /// Names of the tools permitted for the current turn.
        allowed_tools: Vec<String>,
        /// Chat history carried with the error.
        chat_history: Box<Vec<Message>>,
    },
}
132
133/// Surface [`crate::memory::ConversationMemory`] failures through the existing
134/// [`CompletionError::RequestError`] variant so adding memory support does not
135/// require a new top-level [`PromptError`] arm in downstream exhaustive matchers.
136impl From<crate::memory::MemoryError> for PromptError {
137    fn from(err: crate::memory::MemoryError) -> Self {
138        Self::CompletionError(CompletionError::RequestError(Box::new(err)))
139    }
140}
141
142impl PromptError {
143    pub(crate) fn prompt_cancelled(
144        chat_history: impl IntoIterator<Item = Message>,
145        reason: impl Into<String>,
146    ) -> Self {
147        Self::PromptCancelled {
148            chat_history: chat_history.into_iter().collect(),
149            reason: reason.into(),
150        }
151    }
152}
153
/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// An error occurred during the prompt execution.
    ///
    /// The inner [`PromptError`] is boxed (presumably to keep this enum small,
    /// since some `PromptError` variants carry chat history).
    #[error("PromptError: {0}")]
    PromptError(#[from] Box<PromptError>),

    /// Failed to deserialize the model's response into the target type.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned an empty response.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
169
/// A context document attached to a completion request.
///
/// Its [`Display`](std::fmt::Display) form is a `<file id: …>` block, with an
/// optional `<metadata … />` line built from `additional_props`. That rendered
/// text is what [`CompletionRequest::normalized_documents`] turns into a
/// document part.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Stable document identifier included in the serialized context block.
    pub id: String,
    /// Text content passed to the model as retrieval or static context.
    pub text: String,
    /// Additional string metadata rendered before the document text.
    ///
    /// Flattened by serde, so these keys serialize inline next to `id` and `text`.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
180
181impl std::fmt::Display for Document {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        write!(
184            f,
185            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
186            self.id,
187            if self.additional_props.is_empty() {
188                self.text.clone()
189            } else {
190                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
191                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
192                let metadata = sorted_props
193                    .iter()
194                    .map(|(k, v)| format!("{k}: {v:?}"))
195                    .collect::<Vec<_>>()
196                    .join(" ");
197                format!("<metadata {} />\n{}", metadata, self.text)
198            }
199        )
200    }
201}
202
/// Definition of a tool advertised to the model through [`CompletionRequest::tools`].
///
/// Unlike [`ProviderToolDefinition`], which describes a tool hosted by the
/// provider, this describes a tool that is registered by name on the Rig side.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name exposed to the model. It must match the registered tool name.
    pub name: String,
    /// Human-readable description sent to the model.
    pub description: String,
    /// JSON Schema describing tool arguments.
    pub parameters: serde_json::Value,
}
212
/// Provider-native tool definition.
///
/// Stored under `additional_params.tools` and forwarded by providers that support
/// provider-managed tools.
///
/// Serializes as one flat JSON object. `kind` becomes the `"type"` key, and each
/// `config` entry is inlined beside it; nothing extra is emitted when `config` is
/// empty. For example: `{"type": "web_search", ...}`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type/kind name as expected by the target provider (for example `web_search`).
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional provider-specific configuration for this hosted tool.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
226
227impl ProviderToolDefinition {
228    /// Creates a provider-hosted tool definition by type.
229    pub fn new(kind: impl Into<String>) -> Self {
230        Self {
231            kind: kind.into(),
232            config: serde_json::Map::new(),
233        }
234    }
235
236    /// Adds a provider-specific configuration key/value.
237    pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
238        self.config.insert(key.into(), value);
239        self
240    }
241}
242
243// ================================================================
244// Implementations
245// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [`IntoFuture`](std::future::IntoFuture), so it can
    /// be `.await`ed directly. It resolves to `Result<String, PromptError>`.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
261
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The prompt and any assistant or tool messages produced during the turn
    /// are appended to `chat_history`. Callers should pass the current
    /// conversation history and should not push the user prompt themselves
    /// before calling this method.
    ///
    /// # Errors
    /// Resolves to a [`PromptError`] when the completion or a tool call fails.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: &mut Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend;
}
283
/// Trait defining a high-level typed prompt interface for structured output.
///
/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
/// generating a JSON schema from the target type and deserializing the response.
///
/// # Example
/// ```rust,ignore
/// use rig_core::prelude::*;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct WeatherForecast {
///     city: String,
///     temperature_f: f64,
///     conditions: String,
/// }
///
/// let agent = client.agent("gpt-4o").build();
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .await?;
/// ```
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The type of the typed prompt request returned by `prompt_typed`.
    ///
    /// Awaiting it (via `IntoFuture`) yields the deserialized `T` or a
    /// [`StructuredOutputError`].
    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Type Parameters
    /// * `T` - The target type to deserialize the response into. Must implement
    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
    ///   and `WasmCompatSend` (for async compatibility).
    ///
    /// # Errors
    /// The awaited request can fail in three ways:
    /// * [`StructuredOutputError::PromptError`] if the prompt itself fails.
    /// * [`StructuredOutputError::DeserializationError`] if the response does not parse as `T`.
    /// * [`StructuredOutputError::EmptyResponse`] if the model returned no content.
    ///
    /// # Example
    /// ```rust,ignore
    /// // Type can be inferred
    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
    ///
    /// // Or specified explicitly with turbofish
    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
    /// ```
    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
336
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// `prompt` is the new input message. `chat_history` supplies prior messages,
    /// each converted with `Into<Message>`.
    ///
    /// # Errors
    /// Resolves to a [`CompletionError`] if the builder cannot be prepared.
    fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
360
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// `T` is the provider's raw response type; [`CompletionModel::completion`] returns
/// `CompletionResponse<Self::Response>`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Used to pair reasoning input items with their output items in multi-turn.
    pub message_id: Option<String>,
}
376
/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
/// It is required of [`CompletionModel::StreamingResponse`].
pub trait GetTokenUsage {
    /// Returns token usage for this response. Zero-valued usage is
    /// [`Usage`]'s documented sentinel for missing provider usage metrics;
    /// response types that carry no usage return [`Usage::new`].
    fn token_usage(&self) -> crate::completion::Usage;
}
386
387impl GetTokenUsage for () {
388    fn token_usage(&self) -> crate::completion::Usage {
389        crate::completion::Usage::new()
390    }
391}
392
393impl<T> GetTokenUsage for Option<T>
394where
395    T: GetTokenUsage,
396{
397    fn token_usage(&self) -> crate::completion::Usage {
398        if let Some(usage) = self {
399            usage.token_usage()
400        } else {
401            crate::completion::Usage::new()
402        }
403    }
404}
405
406/// Struct representing the token usage for a completion request.
407/// If tokens used are `0`, then the provider failed to supply token usage metrics.
408#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
409pub struct Usage {
410    /// The number of input ("prompt") tokens used in a given request.
411    pub input_tokens: u64,
412    /// The number of output ("completion") tokens used in a given request.
413    pub output_tokens: u64,
414    /// We store this separately as some providers may only report one number
415    pub total_tokens: u64,
416    /// The number of input tokens read from a provider-managed cache
417    pub cached_input_tokens: u64,
418    /// The number of input tokens written to a provider-managed cache
419    pub cache_creation_input_tokens: u64,
420    /// The number of tool-use prompt tokens used in a given request.
421    #[serde(default)]
422    pub tool_use_prompt_tokens: u64,
423    /// The number of tokens spent on internal reasoning / "thoughts" by reasoning-capable
424    /// models (e.g. Gemini thinking, Anthropic extended thinking, OpenAI o-series).
425    pub reasoning_tokens: u64,
426}
427
428impl Usage {
429    /// Creates a new instance of `Usage`.
430    pub fn new() -> Self {
431        Self {
432            input_tokens: 0,
433            output_tokens: 0,
434            total_tokens: 0,
435            cached_input_tokens: 0,
436            cache_creation_input_tokens: 0,
437            tool_use_prompt_tokens: 0,
438            reasoning_tokens: 0,
439        }
440    }
441
442    /// Whether any usage values are set and non-zero.
443    ///
444    /// Zero-valued usage is this type's documented sentinel for "the provider
445    /// supplied no usage metrics", so `false` means usage was not reported.
446    pub fn has_values(&self) -> bool {
447        *self != Self::new()
448    }
449}
450
451impl Default for Usage {
452    fn default() -> Self {
453        Self::new()
454    }
455}
456
457impl Add for Usage {
458    type Output = Self;
459
460    fn add(self, other: Self) -> Self::Output {
461        Self {
462            input_tokens: self.input_tokens + other.input_tokens,
463            output_tokens: self.output_tokens + other.output_tokens,
464            total_tokens: self.total_tokens + other.total_tokens,
465            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
466            cache_creation_input_tokens: self.cache_creation_input_tokens
467                + other.cache_creation_input_tokens,
468            tool_use_prompt_tokens: self.tool_use_prompt_tokens + other.tool_use_prompt_tokens,
469            reasoning_tokens: self.reasoning_tokens + other.reasoning_tokens,
470        }
471    }
472}
473
474impl AddAssign for Usage {
475    fn add_assign(&mut self, other: Self) {
476        self.input_tokens += other.input_tokens;
477        self.output_tokens += other.output_tokens;
478        self.total_tokens += other.total_tokens;
479        self.cached_input_tokens += other.cached_input_tokens;
480        self.cache_creation_input_tokens += other.cache_creation_input_tokens;
481        self.tool_use_prompt_tokens += other.tool_use_prompt_tokens;
482        self.reasoning_tokens += other.reasoning_tokens;
483    }
484}
485
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    ///
    /// Must implement [`GetTokenUsage`] so token usage can be read from it.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Provider client type used to construct this model.
    type Client;

    /// Construct a model handle from a provider client and model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Starts a streaming completion for the given completion request.
    ///
    /// Resolves to a [`StreamingCompletionResponse`] whose items are built from
    /// [`Self::StreamingResponse`], or to a [`CompletionError`].
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    ///
    /// The builder holds a clone of this model handle.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
527
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Optional model override for this request.
    /// (When `None`, presumably the model the handle was created with is used — confirm per provider.)
    pub model: Option<String>,
    /// Legacy preamble field preserved for backwards compatibility.
    ///
    /// New code should prefer a leading [`Message::System`]
    /// in `chat_history` as the canonical representation of system instructions.
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider.
    ///
    /// Providers that cannot take documents directly can call
    /// [`CompletionRequest::normalized_documents`] to turn them into a user message.
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Controls whether/how the model provider must use tools before providing a response
    /// (see [`ToolChoice`]).
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider.
    ///
    /// Provider-hosted tools added via [`CompletionRequest::with_provider_tool`] /
    /// [`CompletionRequest::with_provider_tools`] are stored under the `tools` key here.
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}
557
558impl CompletionRequest {
559    /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
560    /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
561    pub fn output_schema_name(&self) -> Option<String> {
562        self.output_schema.as_ref().map(|schema| {
563            schema
564                .as_object()
565                .and_then(|o| o.get("title"))
566                .and_then(|v| v.as_str())
567                .unwrap_or("response_schema")
568                .to_string()
569        })
570    }
571
572    /// Returns documents normalized into a message (if any).
573    /// Most providers do not accept documents directly as input, so it needs to convert into a
574    /// `Message` so that it can be incorporated into `chat_history`.
575    pub fn normalized_documents(&self) -> Option<Message> {
576        Self::normalized_documents_from(&self.documents)
577    }
578
579    fn normalized_documents_from(documents: &[Document]) -> Option<Message> {
580        if documents.is_empty() {
581            return None;
582        }
583
584        // Most providers will convert documents into a text unless it can handle document messages.
585        // We use `UserContent::document` for those who handle it directly!
586        let messages = documents
587            .iter()
588            .map(|doc| {
589                UserContent::document(
590                    doc.to_string(),
591                    // In the future, we can customize `Document` to pass these extra types through.
592                    // Most providers ditch these but they might want to use them.
593                    Some(DocumentMediaType::TXT),
594                )
595            })
596            .collect::<Vec<_>>();
597
598        OneOrMany::from_iter_optional(messages).map(|content| Message::User { content })
599    }
600
601    pub(crate) fn chat_history_with_documents(&self) -> Vec<Message> {
602        let mut chat_history = self.chat_history.iter().cloned().collect::<Vec<_>>();
603        if let Some(documents) = self.normalized_documents() {
604            let insert_at = chat_history
605                .iter()
606                .position(|message| !matches!(message, Message::System { .. }))
607                .unwrap_or(chat_history.len());
608            chat_history.insert(insert_at, documents);
609        }
610        chat_history
611    }
612
613    /// Adds a provider-hosted tool by storing it in `additional_params.tools`.
614    pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
615        self.additional_params =
616            merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
617        self
618    }
619
620    /// Adds provider-hosted tools by storing them in `additional_params.tools`.
621    pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
622        self.additional_params =
623            merge_provider_tools_into_additional_params(self.additional_params, tools);
624        self
625    }
626}
627
628fn merge_provider_tools_into_additional_params(
629    additional_params: Option<serde_json::Value>,
630    provider_tools: Vec<ProviderToolDefinition>,
631) -> Option<serde_json::Value> {
632    if provider_tools.is_empty() {
633        return additional_params;
634    }
635
636    let mut provider_tools_json = provider_tools
637        .into_iter()
638        .map(|ProviderToolDefinition { kind, mut config }| {
639            // Force the provider tool type from the strongly-typed field.
640            config.insert("type".to_string(), serde_json::Value::String(kind));
641            serde_json::Value::Object(config)
642        })
643        .collect::<Vec<_>>();
644
645    let mut params_map = match additional_params {
646        Some(serde_json::Value::Object(map)) => map,
647        Some(serde_json::Value::Bool(stream)) => {
648            let mut map = serde_json::Map::new();
649            map.insert("stream".to_string(), serde_json::Value::Bool(stream));
650            map
651        }
652        _ => serde_json::Map::new(),
653    };
654
655    let mut merged_tools = match params_map.remove("tools") {
656        Some(serde_json::Value::Array(existing)) => existing,
657        _ => Vec::new(),
658    };
659    merged_tools.append(&mut provider_tools_json);
660    params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
661    Some(serde_json::Value::Object(params_map))
662}
663
/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```no_run
/// use rig_core::{
///     client::CompletionClient,
///     providers::openai::{Client, self},
///     completion::{CompletionModel, CompletionRequestBuilder},
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let openai = Client::new("your-openai-api-key")?;
/// let model = openai.completion_model(openai::GPT_5_2);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request).await?;
/// # Ok(())
/// # }
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```no_run
/// use rig_core::{
///     client::CompletionClient,
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let openai = Client::new("your-openai-api-key")?;
/// let model = openai.completion_model(openai::GPT_5_2);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await?;
/// # Ok(())
/// # }
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model handle the builder was created with.
    model: M,
    /// The prompt message (presumably appended as the final chat-history entry
    /// on build, per `CompletionRequest::chat_history` — confirm).
    prompt: Message,
    /// Per-request model override, set via `model` / `model_opt`.
    request_model: Option<String>,
    /// Legacy preamble, set via `preamble` and cleared via `without_preamble`.
    preamble: Option<String>,
    /// Prior messages added via `message` / `messages`.
    chat_history: Vec<Message>,
    /// Context documents added via `document` / `documents`.
    documents: Vec<Document>,
    /// Tool definitions added via `tool` / `tools`.
    tools: Vec<ToolDefinition>,
    /// Provider-hosted tools added via `provider_tool` / `provider_tools`.
    provider_tools: Vec<ProviderToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    /// Provider-specific parameters; repeated `additional_params` calls merge into this.
    additional_params: Option<serde_json::Value>,
    output_schema: Option<schemars::Schema>,
}
728
729impl<M: CompletionModel> CompletionRequestBuilder<M> {
730    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
731        Self {
732            model,
733            prompt: prompt.into(),
734            request_model: None,
735            preamble: None,
736            chat_history: Vec::new(),
737            documents: Vec::new(),
738            tools: Vec::new(),
739            provider_tools: Vec::new(),
740            temperature: None,
741            max_tokens: None,
742            tool_choice: None,
743            additional_params: None,
744            output_schema: None,
745        }
746    }
747
748    /// Sets the preamble for the completion request.
749    pub fn preamble(mut self, preamble: String) -> Self {
750        // Legacy public API: funnel preamble into canonical system messages at build-time.
751        self.preamble = Some(preamble);
752        self
753    }
754
755    /// Overrides the model used for this request.
756    pub fn model(mut self, model: impl Into<String>) -> Self {
757        self.request_model = Some(model.into());
758        self
759    }
760
761    /// Overrides the model used for this request.
762    pub fn model_opt(mut self, model: Option<String>) -> Self {
763        self.request_model = model;
764        self
765    }
766
767    pub fn without_preamble(mut self) -> Self {
768        self.preamble = None;
769        self
770    }
771
772    /// Adds a message to the chat history for the completion request.
773    pub fn message(mut self, message: Message) -> Self {
774        self.chat_history.push(message);
775
776        self
777    }
778
779    /// Adds a list of messages to the chat history for the completion request.
780    pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
781        self.chat_history.extend(messages);
782
783        self
784    }
785
786    /// Adds a document to the completion request.
787    pub fn document(mut self, document: Document) -> Self {
788        self.documents.push(document);
789        self
790    }
791
792    /// Adds a list of documents to the completion request.
793    pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
794        documents
795            .into_iter()
796            .fold(self, |builder, doc| builder.document(doc))
797    }
798
799    /// Adds a tool to the completion request.
800    pub fn tool(mut self, tool: ToolDefinition) -> Self {
801        self.tools.push(tool);
802        self
803    }
804
805    /// Adds a list of tools to the completion request.
806    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
807        tools
808            .into_iter()
809            .fold(self, |builder, tool| builder.tool(tool))
810    }
811
812    /// Adds a provider-hosted tool to the completion request.
813    pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
814        self.provider_tools.push(tool);
815        self
816    }
817
818    /// Adds provider-hosted tools to the completion request.
819    pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
820        tools
821            .into_iter()
822            .fold(self, |builder, tool| builder.provider_tool(tool))
823    }
824
825    /// Adds additional parameters to the completion request.
826    /// This can be used to set additional provider-specific parameters. For example,
827    /// Cohere's completion models accept a `connectors` parameter that can be used to
828    /// specify the data connectors used by Cohere when executing the completion
829    /// (see `examples/cohere_connectors.rs`).
830    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
831        match self.additional_params {
832            Some(params) => {
833                self.additional_params = Some(json_utils::merge(params, additional_params));
834            }
835            None => {
836                self.additional_params = Some(additional_params);
837            }
838        }
839        self
840    }
841
842    /// Sets the additional parameters for the completion request.
843    /// This can be used to set additional provider-specific parameters. For example,
844    /// Cohere's completion models accept a `connectors` parameter that can be used to
845    /// specify the data connectors used by Cohere when executing the completion
846    /// (see `examples/cohere_connectors.rs`).
847    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
848        self.additional_params = additional_params;
849        self
850    }
851
852    /// Sets the temperature for the completion request.
853    pub fn temperature(mut self, temperature: f64) -> Self {
854        self.temperature = Some(temperature);
855        self
856    }
857
858    /// Sets the temperature for the completion request.
859    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
860        self.temperature = temperature;
861        self
862    }
863
864    /// Sets the max tokens for the completion request.
865    /// Note: This is required if using Anthropic
866    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
867        self.max_tokens = Some(max_tokens);
868        self
869    }
870
871    /// Sets the max tokens for the completion request.
872    /// Note: This is required if using Anthropic
873    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
874        self.max_tokens = max_tokens;
875        self
876    }
877
878    /// Sets the thing.
879    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
880        self.tool_choice = Some(tool_choice);
881        self
882    }
883
884    /// Sets the output schema for structured output. When set, providers that support
885    /// native structured outputs will constrain the model's response to match this schema.
886    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
887    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
888    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
889    /// to still be able to leverage structured outputs.
890    pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
891        self.output_schema = Some(schema);
892        self
893    }
894
895    /// Sets the output schema for structured output from an optional value.
896    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
897    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
898    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
899    /// to still be able to leverage structured outputs.
900    pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
901        self.output_schema = schema;
902        self
903    }
904
905    /// Builds the completion request.
906    pub fn build(self) -> CompletionRequest {
907        // Build the final message list, prepending preamble if present
908        let mut chat_history = self.chat_history;
909        let prompt = self.prompt;
910        if let Some(preamble) = self.preamble {
911            chat_history.insert(0, Message::system(preamble));
912        }
913
914        chat_history.push(prompt.clone());
915
916        let chat_history =
917            OneOrMany::from_iter_optional(chat_history).unwrap_or_else(|| OneOrMany::one(prompt));
918        let additional_params = merge_provider_tools_into_additional_params(
919            self.additional_params,
920            self.provider_tools,
921        );
922
923        CompletionRequest {
924            model: self.request_model,
925            preamble: None,
926            chat_history,
927            documents: self.documents,
928            tools: self.tools,
929            temperature: self.temperature,
930            max_tokens: self.max_tokens,
931            tool_choice: self.tool_choice,
932            additional_params,
933            output_schema: self.output_schema,
934        }
935    }
936
937    /// Sends the completion request to the completion model provider and returns the completion response.
938    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
939        let model = self.model.clone();
940        model.completion(self.build()).await
941    }
942
943    /// Stream the completion request
944    pub async fn stream<'a>(
945        self,
946    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
947    where
948        <M as CompletionModel>::StreamingResponse: 'a,
949        Self: 'a,
950    {
951        let model = self.model.clone();
952        model.stream(self.build()).await
953    }
954}
955
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::MockCompletionModel;

    /// Builds a [`Document`] with no additional properties.
    fn test_document(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// Builds a [`CompletionRequest`] with the given chat history and documents, leaving
    /// every other field unset.
    fn request_with(
        chat_history: OneOrMany<Message>,
        documents: Vec<Document>,
    ) -> CompletionRequest {
        CompletionRequest {
            model: None,
            preamble: None,
            chat_history,
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Returns `true` if `message` is a user message containing a document whose rendered
    /// text includes the `<file id: {expected_id}>` header.
    fn is_document_message(message: &Message, expected_id: &str) -> bool {
        let Message::User { content } = message else {
            return false;
        };

        content.iter().any(|content| {
            matches!(
                content,
                UserContent::Document(document)
                    if document.data.to_string().contains(&format!("<file id: {expected_id}>"))
            )
        })
    }

    #[test]
    fn usage_has_values_reflects_the_zero_sentinel() {
        assert!(!Usage::new().has_values());

        let mut usage = Usage::new();
        usage.reasoning_tokens = 1;
        assert!(usage.has_values());
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = test_document("123", "This is a test document.");

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with(
            OneOrMany::one("What is the capital of France?".into()),
            vec![
                test_document("doc1", "Document 1 text."),
                test_document("doc2", "Document 2 text."),
            ],
        );

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with(
            OneOrMany::one("What is the capital of France?".into()),
            Vec::new(),
        );

        assert_eq!(request.normalized_documents(), None);
    }

    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .message(Message::user("History"))
                .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(&history[1], Message::User { .. }));
        assert!(matches!(&history[2], Message::User { .. }));
    }

    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .without_preamble()
                .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert!(matches!(&history[0], Message::User { .. }));
    }

    #[test]
    fn build_places_documents_after_preamble_system_message() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .document(test_document("doc1", "Document text."))
                .build();

        assert_eq!(request.documents.len(), 1);

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(is_document_message(history[1], "doc1"));
        assert!(matches!(history[2], Message::User { .. }));
    }

    #[test]
    fn build_places_documents_after_leading_system_messages_before_prior_history() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .message(Message::system("System one"))
                .message(Message::system("System two"))
                .message(Message::user("Earlier user turn"))
                .message(Message::assistant("Earlier assistant turn"))
                .document(test_document("doc1", "Document text."))
                .build();

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 6);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "System one"
        ));
        assert!(matches!(
            history[1],
            Message::System { content } if content == "System two"
        ));
        assert!(is_document_message(history[2], "doc1"));
        assert!(matches!(history[3], Message::User { .. }));
        assert!(matches!(history[4], Message::Assistant { .. }));
        assert!(matches!(history[5], Message::User { .. }));
    }

    #[test]
    fn build_without_documents_keeps_message_order_unchanged() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .message(Message::system("System prompt"))
                .message(Message::user("Earlier user turn"))
                .build();

        let history = request.chat_history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(history[1], Message::User { .. }));
        assert!(matches!(history[2], Message::User { .. }));
    }

    #[test]
    fn chat_history_with_documents_places_documents_after_leading_system_messages() {
        let request = request_with(
            OneOrMany::many(vec![
                Message::system("System prompt"),
                Message::assistant("Earlier assistant turn"),
                Message::user("Earlier user turn"),
                Message::user("Prompt"),
            ])
            .unwrap(),
            vec![test_document("doc1", "Document text.")],
        );

        assert_eq!(request.documents.len(), 1);

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 5);
        assert!(matches!(history[0], Message::System { .. }));
        assert!(is_document_message(history[1], "doc1"));
        assert!(matches!(history[2], Message::Assistant { .. }));
        assert!(matches!(history[3], Message::User { .. }));
        assert!(matches!(history[4], Message::User { .. }));
    }

    #[test]
    fn chat_history_with_documents_places_documents_before_mid_conversation_system_messages() {
        let request = request_with(
            OneOrMany::many(vec![
                Message::system("Leading system prompt"),
                Message::assistant("Earlier assistant turn"),
                Message::system("Mid-conversation instruction"),
                Message::user("Prompt"),
            ])
            .unwrap(),
            vec![test_document("doc1", "Document text.")],
        );

        let history = request.chat_history_with_documents();
        let history = history.iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 5);
        assert!(matches!(
            history[0],
            Message::System { content } if content == "Leading system prompt"
        ));
        assert!(is_document_message(history[1], "doc1"));
        assert!(matches!(history[2], Message::Assistant { .. }));
        assert!(matches!(
            history[3],
            Message::System { content } if content == "Mid-conversation instruction"
        ));
        assert!(matches!(history[4], Message::User { .. }));
    }

    #[test]
    fn chat_history_with_documents_does_not_duplicate_documents() {
        let request = request_with(
            OneOrMany::many(vec![
                Message::system("System prompt"),
                Message::user("Earlier user turn"),
                Message::assistant("Earlier assistant turn"),
                Message::user("Prompt"),
            ])
            .unwrap(),
            vec![test_document("doc1", "Document text.")],
        );

        let history = request.chat_history_with_documents();
        let document_messages = history
            .iter()
            .filter(|message| is_document_message(message, "doc1"))
            .count();
        assert_eq!(document_messages, 1);
    }
}