rig_core/completion/request.rs
1//! Completion request, response, and provider trait definitions.
2//!
3//! Most applications use [`Prompt`] or [`Chat`] through
4//! [`Agent`](crate::agent::Agent). Provider integrations implement
5//! [`CompletionModel`] and translate [`CompletionRequest`] into their native HTTP
6//! request format.
7//!
8//! # Low-level request example
9//!
10//! ```no_run
11//! use rig_core::{
12//! client::{CompletionClient, ProviderClient},
13//! completion::{AssistantContent, CompletionModel},
14//! providers::openai,
15//! };
16//!
17//! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
18//! let client = openai::Client::from_env()?;
19//! let model = client.completion_model(openai::GPT_5_2);
20//!
21//! let request = model
22//! .completion_request("Who are you?")
23//! .preamble("You are a concise assistant.".to_string())
24//! .temperature(0.5)
25//! .build();
26//!
27//! let response = model.completion(request).await?;
28//! for item in response.choice {
29//! if let AssistantContent::Text(text) = item {
30//! println!("{}", text.text);
31//! }
32//! }
33//! # Ok(())
34//! # }
35//! ```
36
37use super::message::{AssistantContent, DocumentMediaType};
38use crate::message::ToolChoice;
39use crate::streaming::StreamingCompletionResponse;
40use crate::tool::server::ToolServerError;
41use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
42use crate::{OneOrMany, http_client};
43use crate::{
44 json_utils,
45 message::{Message, UserContent},
46 tool::ToolSetError,
47};
48use serde::de::DeserializeOwned;
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use std::ops::{Add, AddAssign};
52use thiserror::Error;
53
54// Errors
/// Errors produced while building a completion request, sending it to a provider,
/// or parsing the provider's response.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request.
    /// On non-wasm targets the boxed error must also be `Send + Sync`.
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request.
    /// On wasm targets the `Send + Sync` requirement is dropped.
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
87
/// Prompt errors returned by the high-level [`Prompt`] and [`Chat`] interfaces.
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server.
    /// The underlying [`ToolServerError`] is boxed in this variant.
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] Box<ToolServerError>),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the number of tools your model has access to
    /// (and then create other agents to share the tool load)
    /// or increase the number of turns given in `.multi_turn()`.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        /// The turn limit that was reached.
        max_turns: usize,
        // NOTE(review): presumably the history accumulated before the limit was hit —
        // confirm at the construction site.
        chat_history: Box<Vec<Message>>,
        /// The prompt associated with the failed conversation.
        prompt: Box<Message>,
    },

    /// A prompting loop was cancelled.
    /// Usually constructed through `PromptError::prompt_cancelled`.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// Chat history attached to the cancellation.
        chat_history: Vec<Message>,
        /// Human-readable cancellation reason (shown in the error message).
        reason: String,
    },

    /// The model emitted a structured tool call for a tool Rig did not allow
    /// for the current turn.
    #[error(
        "UnknownToolCall: model attempted to call unknown or disallowed tool `{tool_name}`. Available tools: {available_tools:?}. Allowed tools for this turn: {allowed_tools:?}"
    )]
    UnknownToolCall {
        /// Name of the tool the model tried to call.
        tool_name: String,
        /// All tools known to the agent.
        available_tools: Vec<String>,
        /// The subset of tools permitted for the current turn.
        allowed_tools: Vec<String>,
        /// Chat history attached to the error.
        chat_history: Box<Vec<Message>>,
    },
}
132
133/// Surface [`crate::memory::ConversationMemory`] failures through the existing
134/// [`CompletionError::RequestError`] variant so adding memory support does not
135/// require a new top-level [`PromptError`] arm in downstream exhaustive matchers.
136impl From<crate::memory::MemoryError> for PromptError {
137 fn from(err: crate::memory::MemoryError) -> Self {
138 Self::CompletionError(CompletionError::RequestError(Box::new(err)))
139 }
140}
141
142impl PromptError {
143 pub(crate) fn prompt_cancelled(
144 chat_history: impl IntoIterator<Item = Message>,
145 reason: impl Into<String>,
146 ) -> Self {
147 Self::PromptCancelled {
148 chat_history: chat_history.into_iter().collect(),
149 reason: reason.into(),
150 }
151 }
152}
153
/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// An error occurred during the prompt execution.
    /// The underlying [`PromptError`] is boxed; `#[from]` converts from `Box<PromptError>`.
    #[error("PromptError: {0}")]
    PromptError(#[from] Box<PromptError>),

    /// Failed to deserialize the model's response into the target type.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned an empty response.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}
169
/// A piece of context (for example a retrieved document) attached to a completion request.
///
/// Its `Display` output wraps the text in a `<file id: ...>` block, with any
/// `additional_props` rendered on a `<metadata ... />` line sorted by key.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Stable document identifier included in the serialized context block.
    pub id: String,
    /// Text content passed to the model as retrieval or static context.
    pub text: String,
    /// Additional string metadata rendered before the document text.
    /// Flattened by serde, so entries serialize as sibling keys of `id` and `text`.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
180
181impl std::fmt::Display for Document {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 write!(
184 f,
185 concat!("<file id: {}>\n", "{}\n", "</file>\n"),
186 self.id,
187 if self.additional_props.is_empty() {
188 self.text.clone()
189 } else {
190 let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
191 sorted_props.sort_by(|a, b| a.0.cmp(b.0));
192 let metadata = sorted_props
193 .iter()
194 .map(|(k, v)| format!("{k}: {v:?}"))
195 .collect::<Vec<_>>()
196 .join(" ");
197 format!("<metadata {} />\n{}", metadata, self.text)
198 }
199 )
200 }
201}
202
/// Definition of a tool exposed to the model, sent in [`CompletionRequest::tools`].
///
/// For tools hosted by the provider itself, see [`ProviderToolDefinition`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Tool name exposed to the model. It must match the registered tool name.
    pub name: String,
    /// Human-readable description sent to the model.
    pub description: String,
    /// JSON Schema describing tool arguments.
    pub parameters: serde_json::Value,
}
212
/// Provider-native tool definition.
///
/// Stored under `additional_params.tools` and forwarded by providers that support
/// provider-managed tools.
///
/// Serializes as a flat JSON object: `kind` becomes the `"type"` key and every
/// `config` entry appears as a sibling key.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ProviderToolDefinition {
    /// Tool type/kind name as expected by the target provider (for example `web_search`).
    #[serde(rename = "type")]
    pub kind: String,
    /// Additional provider-specific configuration for this hosted tool.
    /// Omitted from serialized output when empty.
    #[serde(flatten, default, skip_serializing_if = "serde_json::Map::is_empty")]
    pub config: serde_json::Map<String, serde_json::Value>,
}
226
227impl ProviderToolDefinition {
228 /// Creates a provider-hosted tool definition by type.
229 pub fn new(kind: impl Into<String>) -> Self {
230 Self {
231 kind: kind.into(),
232 config: serde_json::Map::new(),
233 }
234 }
235
236 /// Adds a provider-specific configuration key/value.
237 pub fn with_config(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
238 self.config.insert(key.into(), value);
239 self
240 }
241}
242
243// ================================================================
244// Implementations
245// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
///
/// `prompt` returns a value implementing [`IntoFuture`](std::future::IntoFuture);
/// awaiting it yields the response text or a [`PromptError`].
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
261
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt, together with the existing (possibly empty) chat history,
    /// to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The prompt and any assistant or tool messages produced during the turn
    /// are appended to `chat_history`. Callers should pass the current
    /// conversation history and should not push the user prompt themselves
    /// before calling this method.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: &mut Vec<Message>,
    ) -> impl std::future::Future<Output = Result<String, PromptError>> + WasmCompatSend;
}
283
/// Trait defining a high-level typed prompt interface for structured output.
///
/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
/// generating a JSON schema from the target type and deserializing the response.
///
/// # Example
/// ```rust,ignore
/// use rig_core::prelude::*;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct WeatherForecast {
///     city: String,
///     temperature_f: f64,
///     conditions: String,
/// }
///
/// let agent = client.agent("gpt-4o").build();
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .await?;
/// ```
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The type of the typed prompt request returned by `prompt_typed`.
    ///
    /// Awaiting it (through `IntoFuture`) yields the deserialized `T` or a
    /// [`StructuredOutputError`]. Note that this type additionally requires `T: 'static`.
    type TypedRequest<T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'static;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Type Parameters
    /// * `T` - The target type to deserialize the response into. Must implement
    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
    ///   and `WasmCompatSend` (for async compatibility). Because the returned
    ///   `Self::TypedRequest<T>` requires it, `T` must also be `'static`.
    ///
    /// # Example
    /// ```rust,ignore
    /// // Type can be inferred
    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
    ///
    /// // Or specified explicitly with turbofish
    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
    /// ```
    fn prompt_typed<T>(&self, prompt: impl Into<Message> + WasmCompatSend) -> Self::TypedRequest<T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}
336
/// Trait defining a low-level LLM completion interface.
///
/// Unlike [`Prompt`] and [`Chat`], this does not send anything: it returns a
/// [`CompletionRequestBuilder`] that the caller can customize and then send.
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    fn completion<I, T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: I,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend
    where
        I: IntoIterator<Item = T> + WasmCompatSend,
        T: Into<Message>;
}
360
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// `T` is the provider's raw response type (see [`CompletionModel::Response`]).
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Used to pair reasoning input items with their output items in multi-turn.
    pub message_id: Option<String>,
}
376
/// A trait for reading the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn,
/// where it would otherwise be impossible to obtain usage.
/// [`CompletionModel::StreamingResponse`] is required to implement it.
pub trait GetTokenUsage {
    /// Returns token usage when the response type carries it, and `None` otherwise.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
384
385impl GetTokenUsage for () {
386 fn token_usage(&self) -> Option<crate::completion::Usage> {
387 None
388 }
389}
390
391impl<T> GetTokenUsage for Option<T>
392where
393 T: GetTokenUsage,
394{
395 fn token_usage(&self) -> Option<crate::completion::Usage> {
396 if let Some(usage) = self {
397 usage.token_usage()
398 } else {
399 None
400 }
401 }
402}
403
404/// Struct representing the token usage for a completion request.
405/// If tokens used are `0`, then the provider failed to supply token usage metrics.
406#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
407pub struct Usage {
408 /// The number of input ("prompt") tokens used in a given request.
409 pub input_tokens: u64,
410 /// The number of output ("completion") tokens used in a given request.
411 pub output_tokens: u64,
412 /// We store this separately as some providers may only report one number
413 pub total_tokens: u64,
414 /// The number of input tokens read from a provider-managed cache
415 pub cached_input_tokens: u64,
416 /// The number of input tokens written to a provider-managed cache
417 pub cache_creation_input_tokens: u64,
418 /// The number of tool-use prompt tokens used in a given request.
419 #[serde(default)]
420 pub tool_use_prompt_tokens: u64,
421 /// The number of tokens spent on internal reasoning / "thoughts" by reasoning-capable
422 /// models (e.g. Gemini thinking, Anthropic extended thinking, OpenAI o-series).
423 pub reasoning_tokens: u64,
424}
425
426impl Usage {
427 /// Creates a new instance of `Usage`.
428 pub fn new() -> Self {
429 Self {
430 input_tokens: 0,
431 output_tokens: 0,
432 total_tokens: 0,
433 cached_input_tokens: 0,
434 cache_creation_input_tokens: 0,
435 tool_use_prompt_tokens: 0,
436 reasoning_tokens: 0,
437 }
438 }
439}
440
441impl Default for Usage {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
447impl Add for Usage {
448 type Output = Self;
449
450 fn add(self, other: Self) -> Self::Output {
451 Self {
452 input_tokens: self.input_tokens + other.input_tokens,
453 output_tokens: self.output_tokens + other.output_tokens,
454 total_tokens: self.total_tokens + other.total_tokens,
455 cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
456 cache_creation_input_tokens: self.cache_creation_input_tokens
457 + other.cache_creation_input_tokens,
458 tool_use_prompt_tokens: self.tool_use_prompt_tokens + other.tool_use_prompt_tokens,
459 reasoning_tokens: self.reasoning_tokens + other.reasoning_tokens,
460 }
461 }
462}
463
464impl AddAssign for Usage {
465 fn add_assign(&mut self, other: Self) {
466 self.input_tokens += other.input_tokens;
467 self.output_tokens += other.output_tokens;
468 self.total_tokens += other.total_tokens;
469 self.cached_input_tokens += other.cached_input_tokens;
470 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
471 self.tool_use_prompt_tokens += other.tool_use_prompt_tokens;
472 self.reasoning_tokens += other.reasoning_tokens;
473 }
474}
475
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    /// It must implement [`GetTokenUsage`] so usage can be read from streamed responses.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// Provider client type used to construct this model.
    type Client;

    /// Construct a model handle from a provider client and model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Sends the given completion request as a streaming call, resolving to a
    /// [`StreamingCompletionResponse`] parameterized by [`Self::StreamingResponse`].
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    ///
    /// The builder holds a clone of this model, so it can be sent directly with
    /// [`CompletionRequestBuilder::send`].
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
517
/// Struct representing a general completion request that can be sent to a completion model provider.
///
/// Usually produced by [`CompletionRequestBuilder::build`]; provider integrations
/// translate it into their native request format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Optional model override for this request.
    pub model: Option<String>,
    /// Legacy preamble field preserved for backwards compatibility.
    ///
    /// New code should prefer a leading [`Message::System`]
    /// in `chat_history` as the canonical representation of system instructions.
    /// [`CompletionRequestBuilder::build`] always leaves this `None` and emits the
    /// preamble as a leading system message instead.
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    /// (see [`CompletionRequest::normalized_documents`] for converting them into a message).
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider.
    /// Provider-hosted tools are stored as an array under the `tools` key.
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}
547
548impl CompletionRequest {
549 /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
550 /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
551 pub fn output_schema_name(&self) -> Option<String> {
552 self.output_schema.as_ref().map(|schema| {
553 schema
554 .as_object()
555 .and_then(|o| o.get("title"))
556 .and_then(|v| v.as_str())
557 .unwrap_or("response_schema")
558 .to_string()
559 })
560 }
561
562 /// Returns documents normalized into a message (if any).
563 /// Most providers do not accept documents directly as input, so it needs to convert into a
564 /// `Message` so that it can be incorporated into `chat_history`.
565 pub fn normalized_documents(&self) -> Option<Message> {
566 if self.documents.is_empty() {
567 return None;
568 }
569
570 // Most providers will convert documents into a text unless it can handle document messages.
571 // We use `UserContent::document` for those who handle it directly!
572 let messages = self
573 .documents
574 .iter()
575 .map(|doc| {
576 UserContent::document(
577 doc.to_string(),
578 // In the future, we can customize `Document` to pass these extra types through.
579 // Most providers ditch these but they might want to use them.
580 Some(DocumentMediaType::TXT),
581 )
582 })
583 .collect::<Vec<_>>();
584
585 OneOrMany::from_iter_optional(messages).map(|content| Message::User { content })
586 }
587
588 /// Adds a provider-hosted tool by storing it in `additional_params.tools`.
589 pub fn with_provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
590 self.additional_params =
591 merge_provider_tools_into_additional_params(self.additional_params, vec![tool]);
592 self
593 }
594
595 /// Adds provider-hosted tools by storing them in `additional_params.tools`.
596 pub fn with_provider_tools(mut self, tools: Vec<ProviderToolDefinition>) -> Self {
597 self.additional_params =
598 merge_provider_tools_into_additional_params(self.additional_params, tools);
599 self
600 }
601}
602
603fn merge_provider_tools_into_additional_params(
604 additional_params: Option<serde_json::Value>,
605 provider_tools: Vec<ProviderToolDefinition>,
606) -> Option<serde_json::Value> {
607 if provider_tools.is_empty() {
608 return additional_params;
609 }
610
611 let mut provider_tools_json = provider_tools
612 .into_iter()
613 .map(|ProviderToolDefinition { kind, mut config }| {
614 // Force the provider tool type from the strongly-typed field.
615 config.insert("type".to_string(), serde_json::Value::String(kind));
616 serde_json::Value::Object(config)
617 })
618 .collect::<Vec<_>>();
619
620 let mut params_map = match additional_params {
621 Some(serde_json::Value::Object(map)) => map,
622 Some(serde_json::Value::Bool(stream)) => {
623 let mut map = serde_json::Map::new();
624 map.insert("stream".to_string(), serde_json::Value::Bool(stream));
625 map
626 }
627 _ => serde_json::Map::new(),
628 };
629
630 let mut merged_tools = match params_map.remove("tools") {
631 Some(serde_json::Value::Array(existing)) => existing,
632 _ => Vec::new(),
633 };
634 merged_tools.append(&mut provider_tools_json);
635 params_map.insert("tools".to_string(), serde_json::Value::Array(merged_tools));
636 Some(serde_json::Value::Object(params_map))
637}
638
/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```no_run
/// use rig_core::{
///     client::CompletionClient,
///     providers::openai::{Client, self},
///     completion::{CompletionModel, CompletionRequestBuilder},
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let openai = Client::new("your-openai-api-key")?;
/// let model = openai.completion_model(openai::GPT_5_2);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request).await?;
/// # Ok(())
/// # }
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```no_run
/// use rig_core::{
///     client::CompletionClient,
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// let openai = Client::new("your-openai-api-key")?;
/// let model = openai.completion_model(openai::GPT_5_2);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await?;
/// # Ok(())
/// # }
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model the request is sent through by `send`.
    model: M,
    /// The prompt; `build` always appends it as the last chat-history message.
    prompt: Message,
    /// Per-request model override, copied to `CompletionRequest::model`.
    request_model: Option<String>,
    /// Optional preamble; `build` emits it as a leading system message.
    preamble: Option<String>,
    /// Messages placed between the preamble and the prompt.
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    /// Provider-hosted tools; `build` merges them into `additional_params.tools`.
    provider_tools: Vec<ProviderToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
    output_schema: Option<schemars::Schema>,
}
703
704impl<M: CompletionModel> CompletionRequestBuilder<M> {
705 pub fn new(model: M, prompt: impl Into<Message>) -> Self {
706 Self {
707 model,
708 prompt: prompt.into(),
709 request_model: None,
710 preamble: None,
711 chat_history: Vec::new(),
712 documents: Vec::new(),
713 tools: Vec::new(),
714 provider_tools: Vec::new(),
715 temperature: None,
716 max_tokens: None,
717 tool_choice: None,
718 additional_params: None,
719 output_schema: None,
720 }
721 }
722
723 /// Sets the preamble for the completion request.
724 pub fn preamble(mut self, preamble: String) -> Self {
725 // Legacy public API: funnel preamble into canonical system messages at build-time.
726 self.preamble = Some(preamble);
727 self
728 }
729
730 /// Overrides the model used for this request.
731 pub fn model(mut self, model: impl Into<String>) -> Self {
732 self.request_model = Some(model.into());
733 self
734 }
735
736 /// Overrides the model used for this request.
737 pub fn model_opt(mut self, model: Option<String>) -> Self {
738 self.request_model = model;
739 self
740 }
741
742 pub fn without_preamble(mut self) -> Self {
743 self.preamble = None;
744 self
745 }
746
747 /// Adds a message to the chat history for the completion request.
748 pub fn message(mut self, message: Message) -> Self {
749 self.chat_history.push(message);
750
751 self
752 }
753
754 /// Adds a list of messages to the chat history for the completion request.
755 pub fn messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
756 self.chat_history.extend(messages);
757
758 self
759 }
760
761 /// Adds a document to the completion request.
762 pub fn document(mut self, document: Document) -> Self {
763 self.documents.push(document);
764 self
765 }
766
767 /// Adds a list of documents to the completion request.
768 pub fn documents(self, documents: impl IntoIterator<Item = Document>) -> Self {
769 documents
770 .into_iter()
771 .fold(self, |builder, doc| builder.document(doc))
772 }
773
774 /// Adds a tool to the completion request.
775 pub fn tool(mut self, tool: ToolDefinition) -> Self {
776 self.tools.push(tool);
777 self
778 }
779
780 /// Adds a list of tools to the completion request.
781 pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
782 tools
783 .into_iter()
784 .fold(self, |builder, tool| builder.tool(tool))
785 }
786
787 /// Adds a provider-hosted tool to the completion request.
788 pub fn provider_tool(mut self, tool: ProviderToolDefinition) -> Self {
789 self.provider_tools.push(tool);
790 self
791 }
792
793 /// Adds provider-hosted tools to the completion request.
794 pub fn provider_tools(self, tools: Vec<ProviderToolDefinition>) -> Self {
795 tools
796 .into_iter()
797 .fold(self, |builder, tool| builder.provider_tool(tool))
798 }
799
800 /// Adds additional parameters to the completion request.
801 /// This can be used to set additional provider-specific parameters. For example,
802 /// Cohere's completion models accept a `connectors` parameter that can be used to
803 /// specify the data connectors used by Cohere when executing the completion
804 /// (see `examples/cohere_connectors.rs`).
805 pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
806 match self.additional_params {
807 Some(params) => {
808 self.additional_params = Some(json_utils::merge(params, additional_params));
809 }
810 None => {
811 self.additional_params = Some(additional_params);
812 }
813 }
814 self
815 }
816
817 /// Sets the additional parameters for the completion request.
818 /// This can be used to set additional provider-specific parameters. For example,
819 /// Cohere's completion models accept a `connectors` parameter that can be used to
820 /// specify the data connectors used by Cohere when executing the completion
821 /// (see `examples/cohere_connectors.rs`).
822 pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
823 self.additional_params = additional_params;
824 self
825 }
826
827 /// Sets the temperature for the completion request.
828 pub fn temperature(mut self, temperature: f64) -> Self {
829 self.temperature = Some(temperature);
830 self
831 }
832
833 /// Sets the temperature for the completion request.
834 pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
835 self.temperature = temperature;
836 self
837 }
838
839 /// Sets the max tokens for the completion request.
840 /// Note: This is required if using Anthropic
841 pub fn max_tokens(mut self, max_tokens: u64) -> Self {
842 self.max_tokens = Some(max_tokens);
843 self
844 }
845
846 /// Sets the max tokens for the completion request.
847 /// Note: This is required if using Anthropic
848 pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
849 self.max_tokens = max_tokens;
850 self
851 }
852
853 /// Sets the thing.
854 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
855 self.tool_choice = Some(tool_choice);
856 self
857 }
858
859 /// Sets the output schema for structured output. When set, providers that support
860 /// native structured outputs will constrain the model's response to match this schema.
861 /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
862 /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
863 /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
864 /// to still be able to leverage structured outputs.
865 pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
866 self.output_schema = Some(schema);
867 self
868 }
869
870 /// Sets the output schema for structured output from an optional value.
871 /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()` - using this method
872 /// with `Agent::prompt()` will still output a String at the end, it'll just be compatible with whatever
873 /// type you want to use here. This method is primarily an escape hatch for agents being used as tools
874 /// to still be able to leverage structured outputs.
875 pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
876 self.output_schema = schema;
877 self
878 }
879
880 /// Builds the completion request.
881 pub fn build(self) -> CompletionRequest {
882 // Build the final message list, prepending preamble if present
883 let mut chat_history = self.chat_history;
884 let prompt = self.prompt;
885 if let Some(preamble) = self.preamble {
886 chat_history.insert(0, Message::system(preamble));
887 }
888 chat_history.push(prompt.clone());
889
890 let chat_history =
891 OneOrMany::from_iter_optional(chat_history).unwrap_or_else(|| OneOrMany::one(prompt));
892 let additional_params = merge_provider_tools_into_additional_params(
893 self.additional_params,
894 self.provider_tools,
895 );
896
897 CompletionRequest {
898 model: self.request_model,
899 preamble: None,
900 chat_history,
901 documents: self.documents,
902 tools: self.tools,
903 temperature: self.temperature,
904 max_tokens: self.max_tokens,
905 tool_choice: self.tool_choice,
906 additional_params,
907 output_schema: self.output_schema,
908 }
909 }
910
911 /// Sends the completion request to the completion model provider and returns the completion response.
912 pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
913 let model = self.model.clone();
914 model.completion(self.build()).await
915 }
916
917 /// Stream the completion request
918 pub async fn stream<'a>(
919 self,
920 ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
921 where
922 <M as CompletionModel>::StreamingResponse: 'a,
923 Self: 'a,
924 {
925 let model = self.model.clone();
926 model.stream(self.build()).await
927 }
928}
929
#[cfg(test)]
mod tests {

    use super::*;
    use crate::test_utils::MockCompletionModel;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }

    #[test]
    fn preamble_builder_funnels_to_system_message() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .message(Message::user("History"))
                .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 3);
        assert!(matches!(
            &history[0],
            Message::System { content } if content == "System prompt"
        ));
        // Pin the exact ordering: builder history precedes the prompt, which is last.
        assert_eq!(history[1], Message::user("History"));
        assert_eq!(history[2], Message::user("Prompt"));
    }

    #[test]
    fn build_without_preamble_appends_prompt_after_history() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .message(Message::user("First"))
                .message(Message::user("Second"))
                .build();

        assert_eq!(request.preamble, None);

        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(
            history,
            vec![
                Message::user("First"),
                Message::user("Second"),
                Message::user("Prompt"),
            ]
        );
    }

    #[test]
    fn without_preamble_removes_legacy_preamble_injection() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .preamble("System prompt".to_string())
                .without_preamble()
                .build();

        assert_eq!(request.preamble, None);
        let history = request.chat_history.into_iter().collect::<Vec<_>>();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0], Message::user("Prompt"));
    }

    #[test]
    fn build_forwards_sampling_settings() {
        let request =
            CompletionRequestBuilder::new(MockCompletionModel::default(), Message::user("Prompt"))
                .temperature(0.5)
                .max_tokens(128)
                .build();

        assert_eq!(request.temperature, Some(0.5));
        assert_eq!(request.max_tokens, Some(128));
    }
}