rig/completion/request.rs

//! This module provides functionality for working with completion models.
//! It provides traits, structs, and enums for generating completion requests,
//! handling completion responses, and defining completion models.
//!
//! The main traits defined in this module are:
//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
//! - [Chat]: Defines a high-level LLM chat interface with chat history.
//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
//! - [CompletionModel]: Defines a completion model that can be used to generate completion
//!   responses from requests.
//!
//! The [Prompt] and [Chat] traits are high-level traits that users are expected to use
//! to interact with LLM models. Moreover, it is good practice to implement one of these
//! traits for composite agents that use multiple LLM models to generate responses.
//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (i.e.: a private or third-party LLM provider).
//!
//! The module also provides various structs and enums for representing generic completion requests,
//! responses, and errors.
//!
//! Example Usage:
//! ```rust
//! use rig::providers::openai::{Client, self};
//! use rig::completion::*;
//! use rig::message::AssistantContent;
//!
//! // Initialize the OpenAI client and a completion model
//! let openai = Client::new("your-openai-api-key");
//!
//! let gpt_4 = openai.completion_model(openai::GPT_4);
//!
//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ")
//!     .temperature(0.5)
//!     .build();
//!
//! // Send the completion request and get the completion response
//! let response = gpt_4.completion(request)
//!     .await
//!     .expect("Failed to get completion response");
//!
//! // Handle the first piece of content in the completion response
//! match response.choice.first() {
//!     AssistantContent::Text(text) => {
//!         // Handle the completion response as a text message
//!         println!("Received message: {}", text.text);
//!     }
//!     AssistantContent::ToolCall(tool_call) => {
//!         // Handle the completion response as a tool call
//!         println!(
//!             "Received tool call: {} {:?}",
//!             tool_call.function.name, tool_call.function.arguments
//!         );
//!     }
//!     content => println!("Received content: {content:?}"),
//! }
//! ```
//!
//! For more information on how to use the completion functionality, refer to the documentation of
//! the individual traits, structs, and enums defined in this module.

use super::message::{AssistantContent, DocumentMediaType};
use crate::client::FinalCompletionResponse;
#[allow(deprecated)]
use crate::client::completion::CompletionModelHandle;
use crate::message::ToolChoice;
use crate::streaming::StreamingCompletionResponse;
use crate::tool::server::ToolServerError;
use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
use crate::{OneOrMany, http_client, streaming};
use crate::{
    json_utils,
    message::{Message, UserContent},
    tool::ToolSetError,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::{Add, AddAssign};
use std::sync::Arc;
use thiserror::Error;

// Errors
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}

/// Prompt errors
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may need to either reduce the number of tools your model has access to
    /// (and then create other agents to share the tool load)
    /// or increase the number of turns given in `.multi_turn()`.
    #[error("MaxTurnsError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        max_turns: usize,
        chat_history: Box<Vec<Message>>,
        prompt: Box<Message>,
    },

    /// A prompting loop was cancelled.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        chat_history: Box<Vec<Message>>,
        reason: String,
    },
}

impl PromptError {
    pub(crate) fn prompt_cancelled(chat_history: Vec<Message>, reason: impl Into<String>) -> Self {
        Self::PromptCancelled {
            chat_history: Box::new(chat_history),
            reason: reason.into(),
        }
    }
}

/// Errors that can occur when using typed structured output via [`TypedPrompt::prompt_typed`].
#[derive(Debug, Error)]
pub enum StructuredOutputError {
    /// An error occurred during the prompt execution.
    #[error("PromptError: {0}")]
    PromptError(#[from] PromptError),

    /// Failed to deserialize the model's response into the target type.
    #[error("DeserializationError: {0}")]
    DeserializationError(#[from] serde_json::Error),

    /// The model returned an empty response.
    #[error("EmptyResponse: model returned no content")]
    EmptyResponse,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    pub id: String,
    pub text: String,
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}

impl std::fmt::Display for Document {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
            self.id,
            if self.additional_props.is_empty() {
                self.text.clone()
            } else {
                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
                let metadata = sorted_props
                    .iter()
                    .map(|(k, v)| format!("{k}: {v:?}"))
                    .collect::<Vec<_>>()
                    .join(" ");
                format!("<metadata {} />\n{}", metadata, self.text)
            }
        )
    }
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
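    ///
    /// # Example
    /// ```rust,ignore
    /// // A minimal sketch, assuming `agent` is any value implementing [Prompt]
    /// // (e.g. an agent built from a provider client elsewhere).
    /// let answer: String = agent.prompt("What is the capital of France?").await?;
    /// println!("{answer}");
    /// ```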
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}

/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
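    ///
    /// # Example
    /// ```rust,ignore
    /// use rig::message::Message;
    ///
    /// // A minimal sketch, assuming `agent` is any value implementing [Chat].
    /// let history = vec![
    ///     Message::user("My name is Arthur."),
    ///     Message::assistant("Nice to meet you, Arthur!"),
    /// ];
    /// let reply: String = agent.chat("What's my name?", history).await?;
    /// ```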
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}

/// Trait defining a high-level typed prompt interface for structured output.
///
/// This trait provides an ergonomic way to get typed responses from an LLM by automatically
/// generating a JSON schema from the target type and deserializing the response.
///
/// # Example
/// ```rust,ignore
/// use rig::prelude::*;
/// use schemars::JsonSchema;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize, JsonSchema)]
/// struct WeatherForecast {
///     city: String,
///     temperature_f: f64,
///     conditions: String,
/// }
///
/// let agent = client.agent("gpt-4o").build();
/// let forecast: WeatherForecast = agent
///     .prompt_typed("What's the weather in NYC?")
///     .await?;
/// ```
pub trait TypedPrompt: WasmCompatSend + WasmCompatSync {
    /// The type of the typed prompt request returned by `prompt_typed`.
    type TypedRequest<'a, T>: std::future::IntoFuture<Output = Result<T, StructuredOutputError>>
    where
        Self: 'a,
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend + 'a;

    /// Send a prompt and receive a typed structured response.
    ///
    /// The JSON schema for `T` is automatically generated and sent to the provider.
    /// Providers that support native structured outputs will constrain the model's
    /// response to match this schema.
    ///
    /// # Type Parameters
    /// * `T` - The target type to deserialize the response into. Must implement
    ///   `JsonSchema` (for schema generation), `DeserializeOwned` (for deserialization),
    ///   and `WasmCompatSend` (for async compatibility).
    ///
    /// # Example
    /// ```rust,ignore
    /// // Type can be inferred
    /// let forecast: WeatherForecast = agent.prompt_typed("What's the weather?").await?;
    ///
    /// // Or specified explicitly with turbofish
    /// let forecast = agent.prompt_typed::<WeatherForecast>("What's the weather?").await?;
    /// ```
    fn prompt_typed<T>(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> Self::TypedRequest<'_, T>
    where
        T: schemars::JsonSchema + DeserializeOwned + WasmCompatSend;
}

/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
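    ///
    /// # Example
    /// ```rust,ignore
    /// // A minimal sketch, assuming `agent` implements [Completion] and
    /// // `chat_history` is a `Vec<Message>` built elsewhere.
    /// let response = agent
    ///     .completion("Summarize this conversation.", chat_history)
    ///     .await?
    ///     .temperature(0.2) // overrides any temperature pre-populated by the agent
    ///     .send()
    ///     .await?;
    /// ```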
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}

/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more pieces of assistant content.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more pieces of assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
    /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
    /// Used to pair reasoning input items with their output items in multi-turn.
    pub message_id: Option<String>,
}

/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn conversations,
/// where token usage would otherwise be unavailable.
pub trait GetTokenUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}

impl GetTokenUsage for () {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}

impl<T> GetTokenUsage for Option<T>
where
    T: GetTokenUsage,
{
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        if let Some(usage) = self {
            usage.token_usage()
        } else {
            None
        }
    }
}

/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// The total number of tokens used. Stored separately, as some providers may only report one number.
    pub total_tokens: u64,
    /// The number of cached input tokens (from prompt caching). `0` if not reported by the provider.
    pub cached_input_tokens: u64,
}

impl Usage {
    /// Creates a new instance of `Usage`.
    pub fn new() -> Self {
        Self {
            input_tokens: 0,
            output_tokens: 0,
            total_tokens: 0,
            cached_input_tokens: 0,
        }
    }
}

impl Default for Usage {
    fn default() -> Self {
        Self::new()
    }
}

impl Add for Usage {
    type Output = Self;

    fn add(self, other: Self) -> Self::Output {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            total_tokens: self.total_tokens + other.total_tokens,
            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
        }
    }
}

impl AddAssign for Usage {
    fn add_assign(&mut self, other: Self) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.total_tokens += other.total_tokens;
        self.cached_input_tokens += other.cached_input_tokens;
    }
}

/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third-party provider (e.g.: OpenAI) or a local model.
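///
/// # Example
/// ```rust,ignore
/// // A minimal sketch of a custom provider integration. `MyClient`,
/// // `MyResponse`, and `MyStreamingResponse` are illustrative placeholders,
/// // not part of rig.
/// #[derive(Clone)]
/// struct MyModel {
///     client: MyClient,
///     model: String,
/// }
///
/// impl CompletionModel for MyModel {
///     type Response = MyResponse;
///     type StreamingResponse = MyStreamingResponse;
///     type Client = MyClient;
///
///     fn make(client: &Self::Client, model: impl Into<String>) -> Self {
///         Self { client: client.clone(), model: model.into() }
///     }
///
///     async fn completion(
///         &self,
///         request: CompletionRequest,
///     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
///         // Translate `request` into the provider's wire format, call the
///         // provider, and map the result back into a `CompletionResponse`.
///         todo!()
///     }
///
///     // `stream` is implemented analogously for streaming responses.
/// }
/// ```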
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// The client type used to construct this model.
    type Client;

    /// Creates a new instance of this model from a provider client and a model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}

#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}

#[allow(deprecated)]
impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + GetTokenUsage + 'static,
{
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    usage: resp.usage,
                    raw_response: (),
                    message_id: resp.message_id,
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    > {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            let stream = streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            };

            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
    }
}

/// Struct representing a general completion request that can be sent to a completion model provider.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Optional model override for this request.
    pub model: Option<String>,
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence there is *always* at least one message).
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether the model provider is required to use tools before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
    /// Optional JSON Schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    pub output_schema: Option<schemars::Schema>,
}

impl CompletionRequest {
    /// Extracts a name from the output schema's `"title"` field, falling back to `"response_schema"`.
    /// Useful for providers that require a name alongside the JSON Schema (e.g., OpenAI).
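    ///
    /// ```rust,ignore
    /// // Sketch: with a schema generated from a type named `WeatherForecast`,
    /// // this typically yields `Some("WeatherForecast".to_string())`.
    /// let name = request.output_schema_name();
    /// ```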
    pub fn output_schema_name(&self) -> Option<String> {
        self.output_schema.as_ref().map(|schema| {
            schema
                .as_object()
                .and_then(|o| o.get("title"))
                .and_then(|v| v.as_str())
                .unwrap_or("response_schema")
                .to_string()
        })
    }

    /// Returns the request's documents normalized into a single message (if there are any).
    /// Most providers do not accept documents directly as input, so the documents need to be
    /// converted into a `Message` that can be incorporated into the `chat_history`.
    pub fn normalized_documents(&self) -> Option<Message> {
        if self.documents.is_empty() {
            return None;
        }

        // Most providers will convert documents into plain text unless they can handle document
        // messages. We use `UserContent::document` for those that handle it directly!
        let messages = self
            .documents
            .iter()
            .map(|doc| {
                UserContent::document(
                    doc.to_string(),
                    // In the future, we can customize `Document` to pass these extra types through.
                    // Most providers ditch these but they might want to use them.
                    Some(DocumentMediaType::TXT),
                )
            })
            .collect::<Vec<_>>();

        Some(Message::User {
            content: OneOrMany::many(messages).expect("There will be at least one document"),
        })
    }
}

/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    model: M,
    prompt: Message,
    request_model: Option<String>,
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
    output_schema: Option<schemars::Schema>,
}

impl<M: CompletionModel> CompletionRequestBuilder<M> {
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            request_model: None,
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Overrides the model used for this request.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.request_model = Some(model.into());
        self
    }

    /// Overrides the model used for this request with an optional value.
    pub fn model_opt(mut self, model: Option<String>) -> Self {
        self.request_model = model;
        self
    }

    /// Removes the preamble from the completion request.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
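    ///
    /// ```rust,ignore
    /// use serde_json::json;
    ///
    /// // Sketch: repeated calls merge parameters rather than replacing them,
    /// // so both keys below end up in the final request.
    /// let builder = builder
    ///     .additional_params(json!({ "connectors": [{ "id": "web-search" }] }))
    ///     .additional_params(json!({ "citation_quality": "accurate" }));
    /// ```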
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request from an optional value.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request from an optional value.
    /// Note: This is required if using Anthropic.
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the tool choice for the completion request.
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Sets the output schema for structured output. When set, providers that support
    /// native structured outputs will constrain the model's response to match this schema.
    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()`. Using this
    /// method with `Agent::prompt()` will still output a `String` at the end; it will just be
    /// compatible with whatever type you want to use here. This method is primarily an escape
    /// hatch so that agents being used as tools can still leverage structured outputs.
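    ///
    /// ```rust,ignore
    /// use schemars::schema_for;
    ///
    /// // Sketch, assuming `WeatherForecast` derives `schemars::JsonSchema`.
    /// let request = model
    ///     .completion_request("What's the weather in NYC?")
    ///     .output_schema(schema_for!(WeatherForecast))
    ///     .build();
    /// ```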
    pub fn output_schema(mut self, schema: schemars::Schema) -> Self {
        self.output_schema = Some(schema);
        self
    }

    /// Sets the output schema for structured output from an optional value.
    /// NOTE: For direct type conversion, you may want to use `Agent::prompt_typed()`. Using this
    /// method with `Agent::prompt()` will still output a `String` at the end; it will just be
    /// compatible with whatever type you want to use here. This method is primarily an escape
    /// hatch so that agents being used as tools can still leverage structured outputs.
    pub fn output_schema_opt(mut self, schema: Option<schemars::Schema>) -> Self {
        self.output_schema = schema;
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be at least the prompt");

        CompletionRequest {
            model: self.request_model,
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            tool_choice: self.tool_choice,
            additional_params: self.additional_params,
            output_schema: self.output_schema,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        let model = self.model.clone();
        model.completion(self.build()).await
    }

    /// Streams the completion request.
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
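
    // Sanity check: `Usage` values accumulate field-by-field via the
    // `Add`/`AddAssign` impls defined above.
    #[test]
    fn test_usage_accumulation() {
        let mut total = Usage::new();
        total += Usage {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            cached_input_tokens: 2,
        };
        let total = total
            + Usage {
                input_tokens: 1,
                output_tokens: 1,
                total_tokens: 2,
                cached_input_tokens: 0,
            };
        assert_eq!(total.input_tokens, 11);
        assert_eq!(total.output_tokens, 6);
        assert_eq!(total.total_tokens, 17);
        assert_eq!(total.cached_input_tokens, 2);
    }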
}