// rig/completion/request.rs
1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
16//! The [Completion] trait defines a lower level interface that is useful when the user want
17//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
37//! let request = gpt_4.completion_request("Who are you?")
38//!     .preamble("\
39//!         You are Marvin, an extremely smart but depressed robot who is \
40//!         nonetheless helpful towards humanity.\
41//!     ")
42//!     .temperature(0.5)
43//!     .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//!     .await
48//!     .expect("Failed to get completion response");
49//!
50//! // Handle the completion response
//! match response.choice {
52//!     ModelChoice::Message(message) => {
53//!         // Handle the completion response as a message
54//!         println!("Received message: {}", message);
55//!     }
56//!     ModelChoice::ToolCall(tool_name, tool_params) => {
57//!         // Handle the completion response as a tool call
58//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
59//!     }
60//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, DocumentMediaType};
67use crate::message::ToolChoice;
68use crate::streaming::StreamingCompletionResponse;
69use crate::tool::server::ToolServerError;
70use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
71use crate::{OneOrMany, http_client};
72use crate::{
73    json_utils,
74    message::{Message, UserContent},
75    tool::ToolSetError,
76};
77use serde::de::DeserializeOwned;
78use serde::{Deserialize, Serialize};
79use std::collections::HashMap;
80use std::ops::{Add, AddAssign};
81use thiserror::Error;
82
83// Errors
/// Errors that can occur while sending a completion request to a provider
/// or while processing its response.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    // On non-wasm targets the boxed error must be `Send + Sync` so the
    // variant can cross thread boundaries; the wasm variant below drops
    // those bounds.
    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
116
/// Errors that can occur while running a high-level prompt or chat flow.
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] Box<ToolServerError>),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        /// The turn limit that was reached.
        max_turns: usize,
        /// The conversation accumulated up to the point the limit was hit.
        chat_history: Box<Vec<Message>>,
        /// The original prompt that started the loop.
        prompt: Box<Message>,
    },

    /// A prompting loop was cancelled.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// The conversation accumulated before cancellation.
        chat_history: Vec<Message>,
        /// Human-readable reason the loop was cancelled.
        reason: String,
    },
}
149
150impl PromptError {
151    pub(crate) fn prompt_cancelled(
152        chat_history: impl IntoIterator<Item = Message>,
153        reason: impl Into<String>,
154    ) -> Self {
155        Self::PromptCancelled {
156            chat_history: chat_history.into_iter().collect(),
157            reason: reason.into(),
158        }
159    }
160}
161
/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// An error occurred during the prompt execution (carried as a boxed [`PromptError`]).
    #[error("PromptError: {0}")]
    PromptError(#[from] Box<PromptError>),

    /// Failed to deserialize the model's response into the target type.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned an empty response.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
177
/// A text document that can be attached to a completion request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    // Identifier for the document (rendered in the `<file id: …>` header).
    pub id: String,
    // The document's raw text content.
    pub text: String,
    // Extra key/value properties; `flatten` merges them into the same JSON
    // object as `id`/`text` when (de)serialized.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
185
186impl std::fmt::Display for Document {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        write!(
189            f,
190            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
191            self.id,
192            if self.additional_props.is_empty() {
193                self.text.clone()
194            } else {
195                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
196                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
197                let metadata = sorted_props
198                    .iter()
199                    .map(|(k, v)| format!("{k}: {v:?}"))
200                    .collect::<Vec<_>>()
201                    .join(" ");
202                format!("<metadata {} />\n{}", metadata, self.text)
203            }
204        )
205    }
206}
207
/// Definition of a callable tool (name, description and JSON Schema
/// parameters) advertised to a completion model.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    // The tool's name, as the model will reference it in tool calls.
    pub name: String,
    // Natural-language description of what the tool does.
    pub description: String,
    // JSON Schema describing the tool's parameters.
    pub parameters: serde_json::Value,
}
214
/// Provider-native tool definition.
///
/// Stored under `additional_params.tools` and forwarded by providers that support
/// provider-managed tools.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type/kind name as expected by the target provider (for example `web_search`).
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional provider-specific configuration for this hosted tool.
    /// Flattened so the serialized form is a single flat JSON object
    /// alongside `type`; omitted entirely when empty.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
228
229impl ProviderToolDefinition {
230    /// Creates a provider-hosted tool definition by type.
231    pub fn new(kind: impl Into<String>) -> Self {
232        Self {
233            kind: kind.into(),
234            config: serde_json::Map::new(),
235        }
236    }
237
238    /// Adds a provider-specific configuration key/value.
239    pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
240        self.config.insert(key.into(), value);
241        self
242    }
243}
244
245// ================================================================
246// Implementations
247// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value can be `.await`ed directly, since it implements
    /// [`std::future::IntoFuture`].
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
263
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// `chat_history` accepts any iterable whose items convert into [`Message`].
    fn chat<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
283
/// Trait defining a high-level typed prompt interface for structured output.
///
/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
/// generating a JSON schema from the target type and deserializing the response.
///
/// # Example
/// ```rust,ignore
/// use rig::prelude::*;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct WeatherForecast {
///     city: String,
///     temperature_f: f64,
///     conditions: String,
/// }
///
/// let agent = client.agent("gpt-4o").build();
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .await?;
/// ```
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The type of the typed prompt request returned by `prompt_typed`.
    ///
    /// Note: `T` is bound by `'static` here (a generic-associated-type
    /// requirement), unlike in [`TypedPrompt::prompt_typed`] itself.
    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Type Parameters
    /// * `T` - The target type to deserialize the response into. Must implement
    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
    ///   and `WasmCompatSend` (for async compatibility).
    ///
    /// # Example
    /// ```rust,ignore
    /// // Type can be inferred
    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
    ///
    /// // Or specified explicitly with turbofish
    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
    /// ```
    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
336
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// `chat_history` accepts any iterable whose items convert into [`Message`].
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
360
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding (see [`Usage`])
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Used to pair reasoning input items with their output items in multi-turn.
    pub message_id: Option<String>,
}
376
/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
pub trait GetTokenUsage {
    /// Returns the token usage for this response, or `None` when no usage
    /// information is available.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
383
// A unit response carries no token usage information.
impl GetTokenUsage for () {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}
389
390impl<T> GetTokenUsage for Option<T>
391where
392    T: GetTokenUsage,
393{
394    fn token_usage(&self) -> Option<crate::completion::Usage> {
395        if let Some(usage) = self {
396            usage.token_usage()
397        } else {
398            None
399        }
400    }
401}
402
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// The total token count. We store this separately (rather than deriving
    /// it from input + output) as some providers may only report one number.
    pub total_tokens: u64,
    /// The number of input tokens read from a provider-managed cache
    pub cached_input_tokens: u64,
    /// The number of input tokens written to a provider-managed cache
    pub cache_creation_input_tokens: u64,
}
418
impl Usage {
    /// Creates a new instance of `Usage`.
    ///
    /// All counters start at zero which, per the struct docs, also means
    /// "no usage reported yet".
    pub fn new() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        }
    }
}
431
impl Default for Usage {
    /// Equivalent to [`Usage::new`]: every counter is zero.
    fn default() -> Self {
        Self::new()
    }
}
437
438impl Add for Usage {
439    type Output = Self;
440
441    fn add(self, other: Self) -> Self::Output {
442        Self {
443            input_tokens: self.input_tokens + other.input_tokens,
444            output_tokens: self.output_tokens + other.output_tokens,
445            total_tokens: self.total_tokens + other.total_tokens,
446            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
447            cache_creation_input_tokens: self.cache_creation_input_tokens
448                + other.cache_creation_input_tokens,
449        }
450    }
451}
452
impl AddAssign for Usage {
    /// Accumulates `other` into `self`, field by field.
    fn add_assign(&mut self, other: Self) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.total_tokens += other.total_tokens;
        self.cached_input_tokens += other.cached_input_tokens;
        self.cache_creation_input_tokens += other.cache_creation_input_tokens;
    }
}
462
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// The client type this model is constructed from (see [`CompletionModel::make`]).
    type Client;

    /// Constructs a model instance from the given client and model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
502
/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Optional model override for this request.
    // NOTE(review): presumably providers fall back to their configured model
    // when this is `None` — confirm against the provider implementations.
    pub model: Option<String>,
    /// Legacy preamble field preserved for backwards compatibility.
    ///
    /// New code should prefer a leading [`Message::System`]
    /// in `chat_history` as the canonical representation of system instructions.
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}
532
533impl CompletionRequest {
534    /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
535    /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
536    pub fn output_schema_name(&self) -> Option<String> {
537        self.output_schema.as_ref().map(|schema| {
538            schema
539                .as_object()
540                .and_then(|o| o.get("title"))
541                .and_then(|v| v.as_str())
542                .unwrap_or("response_schema")
543                .to_string()
544        })
545    }
546
547    /// Returns documents normalized into a message (if any).
548    /// Most providers do not accept documents directly as input, so it needs to convert into a
549    ///  `Message` so that it can be incorporated into `chat_history` as a
550    pub fn normalized_documents(&self) -> Option<Message> {
551        if self.documents.is_empty() {
552            return None;
553        }
554
555        // Most providers will convert documents into a text unless it can handle document messages.
556        // We use `UserContent::document` for those who handle it directly!
557        let messages = self
558            .documents
559            .iter()
560            .map(|doc| {
561                UserContent::document(
562                    doc.to_string(),
563                    // In the future, we can customize `Document` to pass these extra types through.
564                    // Most providers ditch these but they might want to use them.
565                    Some(DocumentMediaType::TXT),
566                )
567            })
568            .collect::<Vec<_>>();
569
570        OneOrMany::from_iter_optional(messages).map(|content| Message::User { content })
571    }
572
573    /// Adds a provider-hosted tool by storing it in `additional_params.tools`.
574    pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
575        self.additional_params =
576            merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
577        self
578    }
579
580    /// Adds provider-hosted tools by storing them in `additional_params.tools`.
581    pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
582        self.additional_params =
583            merge_provider_tools_into_additional_params(self.additional_params, tools);
584        self
585    }
586}
587
588fn merge_provider_tools_into_additional_params(
589    additional_params: Option<serde_json::Value>,
590    provider_tools: Vec<ProviderToolDefinition>,
591) -> Option<serde_json::Value> {
592    if provider_tools.is_empty() {
593        return additional_params;
594    }
595
596    let mut provider_tools_json = provider_tools
597        .into_iter()
598        .map(|ProviderToolDefinition { kind, mut config }| {
599            // Force the provider tool type from the strongly-typed field.
600            config.insert("type".to_string(), serde_json::Value::String(kind));
601            serde_json::Value::Object(config)
602        })
603        .collect::<Vec<_>>();
604
605    let mut params_map = match additional_params {
606        Some(serde_json::Value::Object(map)) => map,
607        Some(serde_json::Value::Bool(stream)) => {
608            let mut map = serde_json::Map::new();
609            map.insert("stream".to_string(), serde_json::Value::Bool(stream));
610            map
611        }
612        _ => serde_json::Map::new(),
613    };
614
615    let mut merged_tools = match params_map.remove("tools") {
616        Some(serde_json::Value::Array(existing)) => existing,
617        _ => Vec::new(),
618    };
619    merged_tools.append(&mut provider_tools_json);
620    params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
621    Some(serde_json::Value::Object(params_map))
622}
623
/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O).build();
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O).build();
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // The completion model the request will be sent to.
    model: M,
    // The prompt message for this request.
    prompt: Message,
    // Optional per-request model-name override.
    request_model: Option<String>,
    // Optional system preamble (legacy; see `CompletionRequest::preamble`).
    preamble: Option<String>,
    // Prior conversation messages.
    chat_history: Vec<Message>,
    // Documents to attach to the request.
    documents: Vec<Document>,
    // Tool definitions advertised to the model.
    tools: Vec<ToolDefinition>,
    // Provider-hosted tool definitions (merged into `additional_params.tools`).
    provider_tools: Vec<ProviderToolDefinition>,
    // Sampling temperature, if set.
    temperature: Option<f64>,
    // Maximum output tokens, if set.
    max_tokens: Option<u64>,
    // Tool-choice policy, if set.
    tool_choice: Option<ToolChoice>,
    // Extra provider-specific parameters.
    additional_params: Option<serde_json::Value>,
    // JSON Schema for structured output, if set.
    output_schema: Option<schemars::Schema>,
}
683
684impl<M: CompletionModel> CompletionRequestBuilder<M> {
    /// Creates a builder for `model` with the given `prompt`; every other
    /// field starts unset/empty.
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            request_model: None,
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            provider_tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }
702
703    /// Sets the preamble for the completion request.
704    pub fn preamble(mut self, preamble: String) -> Self {
705        // Legacy public API: funnel preamble into canonical system messages at build-time.
706        self.preamble = Some(preamble);
707        self
708    }
709
710    /// Overrides the model used for this request.
711    pub fn model(mut self, model: impl Into<String>) -> Self {
712        self.request_model = Some(model.into());
713        self
714    }
715
716    /// Overrides the model used for this request.
717    pub fn model_opt(mut self, model: Option<String>) -> Self {
718        self.request_model = model;
719        self
720    }
721
    /// Clears any previously set preamble.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }
726
727    /// Adds a message to the chat history for the completion request.
728    pub fn message(mut self, message: Message) -> Self {
729        self.chat_history.push(message);
730
731        self
732    }
733
734    /// Adds a list of messages to the chat history for the completion request.
735    pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
736        self.chat_history.extend(messages);
737
738        self
739    }
740
    /// Adds a single document to the completion request.
    /// See also [`Self::documents`] for adding several at once.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }
746
747    /// Adds a list of documents to the completion request.
748    pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
749        documents
750            .into_iter()
751            .fold(self, |builder, doc| builder.document(doc))
752    }
753
754    /// Adds a tool to the completion request.
755    pub fn tool(mut self, tool: ToolDefinition) -> Self {
756        self.tools.push(tool);
757        self
758    }
759
760    /// Adds a list of tools to the completion request.
761    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
762        tools
763            .into_iter()
764            .fold(self, |builder, tool| builder.tool(tool))
765    }
766
767    /// Adds a provider-hosted tool to the completion request.
768    pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
769        self.provider_tools.push(tool);
770        self
771    }
772
773    /// Adds provider-hosted tools to the completion request.
774    pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
775        tools
776            .into_iter()
777            .fold(self, |builder, tool| builder.provider_tool(tool))
778    }
779
780    /// Adds additional parameters to the completion request.
781    /// This can be used to set additional provider-specific parameters. For example,
782    /// Cohere's completion models accept a `connectors` parameter that can be used to
783    /// specify the data connectors used by Cohere when executing the completion
784    /// (see `examples/cohere_connectors.rs`).
785    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
786        match self.additional_params {
787            Some(params) => {
788                self.additional_params = Some(json_utils::merge(params, additional_params));
789            }
790            None => {
791                self.additional_params = Some(additional_params);
792            }
793        }
794        self
795    }
796
797    /// Sets the additional parameters for the completion request.
798    /// This can be used to set additional provider-specific parameters. For example,
799    /// Cohere's completion models accept a `connectors` parameter that can be used to
800    /// specify the data connectors used by Cohere when executing the completion
801    /// (see `examples/cohere_connectors.rs`).
802    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
803        self.additional_params = additional_params;
804        self
805    }
806
807    /// Sets the temperature for the completion request.
808    pub fn temperature(mut self, temperature: f64) -> Self {
809        self.temperature = Some(temperature);
810        self
811    }
812
813    /// Sets the temperature for the completion request.
814    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
815        self.temperature = temperature;
816        self
817    }
818
819    /// Sets the max tokens for the completion request.
820    /// Note: This is required if using Anthropic
821    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
822        self.max_tokens = Some(max_tokens);
823        self
824    }
825
826    /// Sets the max tokens for the completion request.
827    /// Note: This is required if using Anthropic
828    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
829        self.max_tokens = max_tokens;
830        self
831    }
832
833    /// Sets the thing.
834    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
835        self.tool_choice = Some(tool_choice);
836        self
837    }
838
839    /// Sets the output schema for structured output. When set, providers that support
840    /// native structured outputs will constrain the model's response to match this schema.
841    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
842    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
843    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
844    /// to still be able to leverage structured outputs.
845    pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
846        self.output_schema = Some(schema);
847        self
848    }
849
850    /// Sets the output schema for structured output from an optional value.
851    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
852    /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
853    /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
854    /// to still be able to leverage structured outputs.
855    pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
856        self.output_schema = schema;
857        self
858    }
859
860    /// Builds the completion request.
861    pub fn build(self) -> CompletionRequest {
862        // Build the final message list, prepending preamble if present
863        let mut chat_history = self.chat_history;
864        let prompt = self.prompt;
865        if let Some(preamble) = self.preamble {
866            chat_history.insert(0, Message::system(preamble));
867        }
868        chat_history.push(prompt.clone());
869
870        let chat_history =
871            OneOrMany::from_iter_optional(chat_history).unwrap_or_else(|| OneOrMany::one(prompt));
872        let additional_params = merge_provider_tools_into_additional_params(
873            self.additional_params,
874            self.provider_tools,
875        );
876
877        CompletionRequest {
878            model: self.request_model,
879            preamble: None,
880            chat_history,
881            documents: self.documents,
882            tools: self.tools,
883            temperature: self.temperature,
884            max_tokens: self.max_tokens,
885            tool_choice: self.tool_choice,
886            additional_params,
887            output_schema: self.output_schema,
888        }
889    }
890
891    /// Sends the completion request to the completion model provider and returns the completion response.
892    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
893        let model = self.model.clone();
894        model.completion(self.build()).await
895    }
896
897    /// Stream the completion request
898    pub async fn stream<'a>(
899        self,
900    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
901    where
902        <M as CompletionModel>::StreamingResponse: 'a,
903        Self: 'a,
904    {
905        let model = self.model.clone();
906        model.stream(self.build()).await
907    }
908}
909
#[cfg(test)]
mod tests {

    use super::*;
    use crate::streaming::StreamingCompletionResponse;
    use serde::{Deserialize, Serialize};

    // Minimal stand-in model: lets the tests drive the request builder
    // without any real provider behind it.
    #[derive(Clone)]
    struct DummyModel;

    // Placeholder streaming response type; carries no data.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct DummyStreamingResponse;

    impl GetTokenUsage for DummyStreamingResponse {
        fn token_usage(&self) -> Option<Usage> {
            // The dummy response never reports token usage.
            None
        }
    }

    impl CompletionModel for DummyModel {
        type Response = serde_json::Value;
        type StreamingResponse = DummyStreamingResponse;
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            Self
        }

        // Both entry points fail unconditionally: the tests below only
        // inspect the request produced by the builder, never a provider
        // round-trip.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            Err(CompletionError::ProviderError(
                "dummy completion model".to_string(),
            ))
        }
    }

    // A document with no metadata renders as just the <file> wrapper around
    // its text.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    // Additional properties render as a <metadata …/> line inside the
    // <file> wrapper, with keys in sorted order ("author" before "length").
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    // normalized_documents() converts each attached document into a
    // UserContent::document entry inside a single user message.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    // With no documents attached, normalized_documents() yields nothing.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }

    // build() should fold the preamble into the chat history as a leading
    // system message rather than keeping a separate `preamble` field.
    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .message(Message::user("History"))
            .build();

        assert_eq!(request.preamble, None);

        // Expected order: system preamble, history message, then the prompt.
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        assert!(matches!(&history[1], Message::User { .. }));
        assert!(matches!(&history[2], Message::User { .. }));
    }

    // without_preamble() clears a previously-set preamble, so only the
    // prompt remains in the built history.
    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request = CompletionRequestBuilder::new(DummyModel, Message::user("Prompt"))
            .preamble("System prompt".to_string())
            .without_preamble()
            .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert!(matches!(&history[0], Message::User { .. }));
    }
}