Skip to main content

rig_core/completion/
request.rs

//! Completion request, response, and provider trait definitions.
//!
//! Most applications use [`Prompt`] or [`Chat`] through
//! [`Agent`](crate::agent::Agent). Provider integrations implement
//! [`CompletionModel`] and translate [`CompletionRequest`] into their native HTTP
//! request format.
//!
//! # Low-level request example
//!
//! ```no_run
//! use rig_core::{
//!     client::{CompletionClient, ProviderClient},
//!     completion::{AssistantContent, CompletionModel},
//!     providers::openai,
//! };
//!
//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
//! let client = openai::Client::from_env()?;
//! let model = client.completion_model(openai::GPT_5_2);
//!
//! let request = model
//!     .completion_request("Who are you?")
//!     .preamble("You are a concise assistant.".to_string())
//!     .temperature(0.5)
//!     .build();
//!
//! let response = model.completion(request).await?;
//! for item in response.choice {
//!     if let AssistantContent::Text(text) = item {
//!         println!("{}", text.text);
//!     }
//! }
//! # Ok(())
//! # }
//! ```

37use super::message::{AssistantContent, DocumentMediaType};
38use crate::message::ToolChoice;
39use crate::streaming::StreamingCompletionResponse;
40use crate::tool::server::ToolServerError;
41use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
42use crate::{OneOrMany, http_client};
43use crate::{
44    json_utils,
45    message::{Message, UserContent},
46    tool::ToolSetError,
47};
48use serde::de::DeserializeOwned;
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use std::ops::{Add, AddAssign};
52use thiserror::Error;
53
// Errors
/// Errors produced while building, sending, or interpreting a completion request.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    // Non-wasm targets box a thread-safe (`Send + Sync`) error.
    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    // On wasm there is no cross-thread requirement, so the bound is relaxed.
    #[cfg(target_family = "wasm")]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
87
88/// Prompt errors
89#[derive(Debug, Error)]
90pub enum PromptError {
91    /// Something went wrong with the completion
92    #[error("CompletionError: {0}")]
93    CompletionError(#[from] CompletionError),
94
95    /// There was an error while using a tool
96    #[error("ToolCallError: {0}")]
97    ToolError(#[from] ToolSetError),
98
99    /// There was an issue while executing a tool on a tool server
100    #[error("ToolServerError: {0}")]
101    ToolServerError(#[from] Box<ToolServerError>),
102
103    /// The LLM tried to call too many tools during a multi-turn conversation.
104    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
105    /// or increase the amount of turns given in `.multi_turn()`.
106    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
107    MaxTurnsError {
108        max_turns: usize,
109        chat_history: Box<Vec<Message>>,
110        prompt: Box<Message>,
111    },
112
113    /// A prompting loop was cancelled.
114    #[error("PromptCancelled: {reason}")]
115    PromptCancelled {
116        chat_history: Vec<Message>,
117        reason: String,
118    },
119}
120
/// Surface [`crate::memory::ConversationMemory`] failures through the existing
/// [`CompletionError::RequestError`] variant so adding memory support does not
/// require a new top-level [`PromptError`] arm in downstream exhaustive matchers.
impl From<crate::memory::MemoryError> for PromptError {
    fn from(err: crate::memory::MemoryError) -> Self {
        // Box the memory error into the generic request-error slot.
        Self::CompletionError(CompletionError::RequestError(Box::new(err)))
    }
}
129
130impl PromptError {
131    pub(crate) fn prompt_cancelled(
132        chat_history: impl IntoIterator<Item = Message>,
133        reason: impl Into<String>,
134    ) -> Self {
135        Self::PromptCancelled {
136            chat_history: chat_history.into_iter().collect(),
137            reason: reason.into(),
138        }
139    }
140}
141
/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// An error occurred during the prompt execution.
    #[error("PromptError: {0}")]
    PromptError(#[from] Box<PromptError>),

    /// Failed to deserialize the model's response into the target type.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned an empty response.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
157
/// A context document rendered into the request (see [`Document`]'s `Display` impl
/// for the exact `<file …>` serialization).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Stable document identifier included in the serialized context block.
    pub id: String,
    /// Text content passed to the model as retrieval or static context.
    pub text: String,
    /// Additional string metadata rendered before the document text.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
168
169impl std::fmt::Display for Document {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        write!(
172            f,
173            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
174            self.id,
175            if self.additional_props.is_empty() {
176                self.text.clone()
177            } else {
178                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
179                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
180                let metadata = sorted_props
181                    .iter()
182                    .map(|(k, v)| format!("{k}: {v:?}"))
183                    .collect::<Vec<_>>()
184                    .join(" ");
185                format!("<metadata {} />\n{}", metadata, self.text)
186            }
187        )
188    }
189}
190
/// Definition of a callable tool advertised to the completion model.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name exposed to the model. It must match the registered tool name.
    pub name: String,
    /// Human-readable description sent to the model.
    pub description: String,
    /// JSON Schema describing tool arguments.
    pub parameters: serde_json::Value,
}
200
/// Provider-native tool definition.
///
/// Stored under `additional_params.tools` and forwarded by providers that support
/// provider-managed tools.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type/kind name as expected by the target provider (for example `web_search`).
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional provider-specific configuration for this hosted tool.
    // Flattened so the config keys sit next to `type` in the serialized tool object.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
214
215impl ProviderToolDefinition {
216    /// Creates a provider-hosted tool definition by type.
217    pub fn new(kind: impl Into<String>) -> Self {
218        Self {
219            kind: kind.into(),
220            config: serde_json::Map::new(),
221        }
222    }
223
224    /// Adds a provider-specific configuration key/value.
225    pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
226        self.config.insert(key.into(), value);
227        self
228    }
229}
230
// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    // Returns `IntoFuture` (rather than `Future`) so implementors can expose a
    // configurable request object that is awaited lazily.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
249
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The prompt and any assistant or tool messages produced during the turn
    /// are appended to `chat_history`. Callers should pass the current
    /// conversation history and should not push the user prompt themselves
    /// before calling this method.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: &mut Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend;
}
271
/// Trait defining a high-level typed prompt interface for structured output.
///
/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
/// generating a JSON schema from the target type and deserializing the response.
///
/// # Example
/// ```rust,ignore
/// use rig_core::prelude::*;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct WeatherForecast {
///     city: String,
///     temperature_f: f64,
///     conditions: String,
/// }
///
/// let agent = client.agent("gpt-4o").build();
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .await?;
/// ```
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The type of the typed prompt request returned by `prompt_typed`.
    // Generic associated type: one request type per target `T`, awaited via `IntoFuture`.
    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Type Parameters
    /// * `T` - The target type to deserialize the response into. Must implement
    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
    ///   and `WasmCompatSend` (for async compatibility).
    ///
    /// # Example
    /// ```rust,ignore
    /// // Type can be inferred
    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
    ///
    /// // Or specified explicitly with turbofish
    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
    /// ```
    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
324
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
348
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Used to pair reasoning input items with their output items in multi-turn.
    pub message_id: Option<String>,
}
364
/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
pub trait GetTokenUsage {
    /// Returns token usage when the response type carries it; `None` otherwise.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
372
// The unit response type carries no usage information.
impl GetTokenUsage for () {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}
378
379impl<T> GetTokenUsage for Option<T>
380where
381    T: GetTokenUsage,
382{
383    fn token_usage(&self) -> Option<crate::completion::Usage> {
384        if let Some(usage) = self {
385            usage.token_usage()
386        } else {
387            None
388        }
389    }
390}
391
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// We store this separately as some providers may only report one number
    pub total_tokens: u64,
    /// The number of input tokens read from a provider-managed cache
    pub cached_input_tokens: u64,
    /// The number of input tokens written to a provider-managed cache
    pub cache_creation_input_tokens: u64,
    /// The number of tokens spent on internal reasoning / "thoughts" by reasoning-capable
    /// models (e.g. Gemini thinking, Anthropic extended thinking, OpenAI o-series).
    pub reasoning_tokens: u64,
}
410
impl Usage {
    /// Creates a new instance of `Usage` with every counter zeroed.
    // NOTE: `Default for Usage` delegates here, so this must not call
    // `Self::default()` (that would recurse).
    pub fn new() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
            reasoning_tokens: 0,
        }
    }
}
424
// Default usage is all-zero counters (provider supplied no metrics).
impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}
430
431impl Add for Usage {
432    type Output = Self;
433
434    fn add(self, other: Self) -> Self::Output {
435        Self {
436            input_tokens: self.input_tokens + other.input_tokens,
437            output_tokens: self.output_tokens + other.output_tokens,
438            total_tokens: self.total_tokens + other.total_tokens,
439            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
440            cache_creation_input_tokens: self.cache_creation_input_tokens
441                + other.cache_creation_input_tokens,
442            reasoning_tokens: self.reasoning_tokens + other.reasoning_tokens,
443        }
444    }
445}
446
// In-place accumulation of usage counters, e.g. summing usage across the
// requests of a multi-turn conversation.
impl AddAssign for Usage {
    fn add_assign(&mut self, other: Self) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.total_tokens += other.total_tokens;
        self.cached_input_tokens += other.cached_input_tokens;
        self.cache_creation_input_tokens += other.cache_creation_input_tokens;
        self.reasoning_tokens += other.reasoning_tokens;
    }
}
457
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Provider client type used to construct this model.
    type Client;

    /// Construct a model handle from a provider client and model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
499
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Optional model override for this request.
    pub model: Option<String>,
    /// Legacy preamble field preserved for backwards compatibility.
    ///
    /// New code should prefer a leading [`Message::System`]
    /// in `chat_history` as the canonical representation of system instructions.
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    /// (also carries provider-hosted tools under a `tools` key — see [`ProviderToolDefinition`]).
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}
529
530impl CompletionRequest {
531    /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
532    /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
533    pub fn output_schema_name(&self) -> Option<String> {
534        self.output_schema.as_ref().map(|schema| {
535            schema
536                .as_object()
537                .and_then(|o| o.get("title"))
538                .and_then(|v| v.as_str())
539                .unwrap_or("response_schema")
540                .to_string()
541        })
542    }
543
544    /// Returns documents normalized into a message (if any).
545    /// Most providers do not accept documents directly as input, so it needs to convert into a
546    /// `Message` so that it can be incorporated into `chat_history`.
547    pub fn normalized_documents(&self) -> Option<Message> {
548        if self.documents.is_empty() {
549            return None;
550        }
551
552        // Most providers will convert documents into a text unless it can handle document messages.
553        // We use `UserContent::document` for those who handle it directly!
554        let messages = self
555            .documents
556            .iter()
557            .map(|doc| {
558                UserContent::document(
559                    doc.to_string(),
560                    // In the future, we can customize `Document` to pass these extra types through.
561                    // Most providers ditch these but they might want to use them.
562                    Some(DocumentMediaType::TXT),
563                )
564            })
565            .collect::<Vec<_>>();
566
567        OneOrMany::from_iter_optional(messages).map(|content| Message::User { content })
568    }
569
570    /// Adds a provider-hosted tool by storing it in `additional_params.tools`.
571    pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
572        self.additional_params =
573            merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
574        self
575    }
576
577    /// Adds provider-hosted tools by storing them in `additional_params.tools`.
578    pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
579        self.additional_params =
580            merge_provider_tools_into_additional_params(self.additional_params, tools);
581        self
582    }
583}
584
585fn merge_provider_tools_into_additional_params(
586    additional_params: Option<serde_json::Value>,
587    provider_tools: Vec<ProviderToolDefinition>,
588) -> Option<serde_json::Value> {
589    if provider_tools.is_empty() {
590        return additional_params;
591    }
592
593    let mut provider_tools_json = provider_tools
594        .into_iter()
595        .map(|ProviderToolDefinition { kind, mut config }| {
596            // Force the provider tool type from the strongly-typed field.
597            config.insert("type".to_string(), serde_json::Value::String(kind));
598            serde_json::Value::Object(config)
599        })
600        .collect::<Vec<_>>();
601
602    let mut params_map = match additional_params {
603        Some(serde_json::Value::Object(map)) => map,
604        Some(serde_json::Value::Bool(stream)) => {
605            let mut map = serde_json::Map::new();
606            map.insert("stream".to_string(), serde_json::Value::Bool(stream));
607            map
608        }
609        _ => serde_json::Map::new(),
610    };
611
612    let mut merged_tools = match params_map.remove("tools") {
613        Some(serde_json::Value::Array(existing)) => existing,
614        _ => Vec::new(),
615    };
616    merged_tools.append(&mut provider_tools_json);
617    params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
618    Some(serde_json::Value::Object(params_map))
619}
620
621/// Builder struct for constructing a completion request.
622///
623/// Example usage:
624/// ```no_run
625/// use rig_core::{
626///     client::CompletionClient,
627///     providers::openai::{Client, self},
628///     completion::{CompletionModel, CompletionRequestBuilder},
629/// };
630///
631/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
632/// let openai = Client::new("your-openai-api-key")?;
633/// let model = openai.completion_model(openai::GPT_5_2);
634///
635/// // Create the completion request and execute it separately
636/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
637///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
638///     .temperature(0.5)
639///     .build();
640///
641/// let response = model.completion(request).await?;
642/// # Ok(())
643/// # }
644/// ```
645///
646/// Alternatively, you can execute the completion request directly from the builder:
647/// ```no_run
648/// use rig_core::{
649///     client::CompletionClient,
650///     providers::openai::{Client, self},
651///     completion::CompletionRequestBuilder,
652/// };
653///
654/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
655/// let openai = Client::new("your-openai-api-key")?;
656/// let model = openai.completion_model(openai::GPT_5_2);
657///
658/// // Create the completion request and execute it directly
659/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
660///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
661///     .temperature(0.5)
662///     .send()
663///     .await?;
664/// # Ok(())
665/// # }
666/// ```
667///
668/// Note: It is usually unnecessary to create a completion request builder directly.
669/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // Model handle the request targets.
    model: M,
    // The user prompt for this request.
    prompt: Message,
    // Optional per-request model name override (see `CompletionRequest::model`).
    request_model: Option<String>,
    // Legacy preamble; see `CompletionRequest::preamble`.
    preamble: Option<String>,
    // Prior conversation messages.
    chat_history: Vec<Message>,
    // Context documents; see `CompletionRequest::documents`.
    documents: Vec<Document>,
    // Tool definitions exposed to the model.
    tools: Vec<ToolDefinition>,
    // Provider-hosted tools, stored under `additional_params.tools` per
    // `ProviderToolDefinition`'s contract.
    provider_tools: Vec<ProviderToolDefinition>,
    // Sampling temperature, if set.
    temperature: Option<f64>,
    // Max output tokens, if set (required by some providers, e.g. Anthropic).
    max_tokens: Option<u64>,
    // Tool-choice constraint, if set.
    tool_choice: Option<ToolChoice>,
    // Provider-specific extra parameters.
    additional_params: Option<serde_json::Value>,
    // JSON Schema for structured output.
    output_schema: Option<schemars::Schema>,
}
685
686impl<M: CompletionModel> CompletionRequestBuilder<M> {
    /// Creates a builder for `model` seeded with the given `prompt`.
    /// Every other request field starts unset/empty.
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            request_model: None,
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            provider_tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }
704
    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        // Legacy public API: funnel preamble into canonical system messages at build-time.
        self.preamble = Some(preamble);
        self
    }

    /// Overrides the model used for this request.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.request_model = Some(model.into());
        self
    }

    /// Overrides (or clears, when `None`) the model used for this request.
    pub fn model_opt(mut self, model: Option<String>) -> Self {
        self.request_model = model;
        self
    }

    /// Clears any previously set preamble.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);

        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
        self.chat_history.extend(messages);

        self
    }
742
    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }
748
749    /// Adds a list of documents to the completion request.
750    pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
751        documents
752            .into_iter()
753            .fold(self, |builder, doc| builder.document(doc))
754    }
755
    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }
761
762    /// Adds a list of tools to the completion request.
763    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
764        tools
765            .into_iter()
766            .fold(self, |builder, tool| builder.tool(tool))
767    }
768
    /// Adds a provider-hosted tool to the completion request.
    pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
        self.provider_tools.push(tool);
        self
    }
774
775    /// Adds provider-hosted tools to the completion request.
776    pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
777        tools
778            .into_iter()
779            .fold(self, |builder, tool| builder.provider_tool(tool))
780    }
781
782    /// Adds additional parameters to the completion request.
783    /// This can be used to set additional provider-specific parameters. For example,
784    /// Cohere's completion models accept a `connectors` parameter that can be used to
785    /// specify the data connectors used by Cohere when executing the completion
786    /// (see `examples/cohere_connectors.rs`).
787    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
788        match self.additional_params {
789            Some(params) => {
790                self.additional_params = Some(json_utils::merge(params, additional_params));
791            }
792            None => {
793                self.additional_params = Some(additional_params);
794            }
795        }
796        self
797    }
798
799    /// Sets the additional parameters for the completion request.
800    /// This can be used to set additional provider-specific parameters. For example,
801    /// Cohere's completion models accept a `connectors` parameter that can be used to
802    /// specify the data connectors used by Cohere when executing the completion
803    /// (see `examples/cohere_connectors.rs`).
804    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
805        self.additional_params = additional_params;
806        self
807    }
808
809    /// Sets the temperature for the completion request.
810    pub fn temperature(mut self, temperature: f64) -> Self {
811        self.temperature = Some(temperature);
812        self
813    }
814
815    /// Sets the temperature for the completion request.
816    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
817        self.temperature = temperature;
818        self
819    }
820
821    /// Sets the max tokens for the completion request.
822    /// Note: This is required if using Anthropic
823    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
824        self.max_tokens = Some(max_tokens);
825        self
826    }
827
828    /// Sets the max tokens for the completion request.
829    /// Note: This is required if using Anthropic
830    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
831        self.max_tokens = max_tokens;
832        self
833    }
834
835    /// Sets the thing.
836    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
837        self.tool_choice = Some(tool_choice);
838        self
839    }
840
841    /// Sets the output schema for structured output. When set, providers that support
842    /// native structured outputs will constrain the model's response to match this schema.
843    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
844    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
845    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
846    /// to still be able to leverage structured outputs.
847    pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
848        self.output_schema = Some(schema);
849        self
850    }
851
852    /// Sets the output schema for structured output from an optional value.
853    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
854    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
855    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
856    /// to still be able to leverage structured outputs.
857    pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
858        self.output_schema = schema;
859        self
860    }
861
862    /// Builds the completion request.
863    pub fn build(self) -> CompletionRequest {
864        // Build the final message list, prepending preamble if present
865        let mut chat_history = self.chat_history;
866        let prompt = self.prompt;
867        if let Some(preamble) = self.preamble {
868            chat_history.insert(0, Message::system(preamble));
869        }
870        chat_history.push(prompt.clone());
871
872        let chat_history =
873            OneOrMany::from_iter_optional(chat_history).unwrap_or_else(|| OneOrMany::one(prompt));
874        let additional_params = merge_provider_tools_into_additional_params(
875            self.additional_params,
876            self.provider_tools,
877        );
878
879        CompletionRequest {
880            model: self.request_model,
881            preamble: None,
882            chat_history,
883            documents: self.documents,
884            tools: self.tools,
885            temperature: self.temperature,
886            max_tokens: self.max_tokens,
887            tool_choice: self.tool_choice,
888            additional_params,
889            output_schema: self.output_schema,
890        }
891    }
892
893    /// Sends the completion request to the completion model provider and returns the completion response.
894    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
895        let model = self.model.clone();
896        model.completion(self.build()).await
897    }
898
899    /// Stream the completion request
900    pub async fn stream<'a>(
901        self,
902    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
903    where
904        <M as CompletionModel>::StreamingResponse: 'a,
905        Self: 'a,
906    {
907        let model = self.model.clone();
908        model.stream(self.build()).await
909    }
910}
911
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::MockCompletionModel;

    /// Convenience factory for a metadata-free document.
    fn plain_doc(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let document = plain_doc("123", "This is a test document.");

        assert_eq!(
            document.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let document = Document {
            id: String::from("123"),
            text: String::from("This is a test document."),
            additional_props: HashMap::from([
                ("author".to_string(), "John Doe".to_string()),
                ("length".to_string(), "42".to_string()),
            ]),
        };

        let expected = "<file id: 123>\n\
                        <metadata author: \"John Doe\" length: \"42\" />\n\
                        This is a test document.\n\
                        </file>\n";
        assert_eq!(document.to_string(), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![
                plain_doc("doc1", "Document 1 text."),
                plain_doc("doc2", "Document 2 text."),
            ],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected_content = OneOrMany::many(vec![
            UserContent::document(
                "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                Some(DocumentMediaType::TXT),
            ),
            UserContent::document(
                "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                Some(DocumentMediaType::TXT),
            ),
        ])
        .expect("There will be at least one document");

        assert_eq!(
            request.normalized_documents(),
            Some(Message::User {
                content: expected_content
            })
        );
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }

    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .message(Message::user("History"))
                .build();

        // The preamble must be folded into the chat history, not kept as-is.
        assert_eq!(request.preamble, None);

        let messages: Vec<_> = request.chat_history.into_iter().collect();
        assert_eq!(messages.len(), 3);
        assert!(matches!(
            &messages[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(&messages[1], Message::User { .. }));
        assert!(matches!(&messages[2], Message::User { .. }));
    }

    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .without_preamble()
                .build();

        assert_eq!(request.preamble, None);
        let messages: Vec<_> = request.chat_history.into_iter().collect();
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], Message::User { .. }));
    }
}