Skip to main content

rig/completion/
request.rs

1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
//! The [Completion] trait defines a lower level interface that is useful when the user wants
17//! to further customize the request before sending it to the completion model provider.
18//!
19//! The [CompletionModel] trait is meant to act as the interface between providers and
20//! the library. It defines the methods that need to be implemented by the user to define
21//! a custom base completion model (i.e.: a private or third party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
37//! let request = gpt_4.completion_request("Who are you?")
38//!     .preamble("\
39//!         You are Marvin, an extremely smart but depressed robot who is \
40//!         nonetheless helpful towards humanity.\
41//!     ")
42//!     .temperature(0.5)
43//!     .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//!     .await
48//!     .expect("Failed to get completion response");
49//!
50//! // Handle the completion response
//! match response.choice {
52//!     ModelChoice::Message(message) => {
53//!         // Handle the completion response as a message
54//!         println!("Received message: {}", message);
55//!     }
56//!     ModelChoice::ToolCall(tool_name, tool_params) => {
57//!         // Handle the completion response as a tool call
58//!         println!("Received tool call: {} {:?}", tool_name, tool_params);
59//!     }
60//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76    json_utils,
77    message::{Message, UserContent},
78    tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
// Errors

/// Errors that can occur at any stage of a completion request: transport,
/// (de)serialization, request construction, response parsing, or an error
/// reported by the provider itself.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    // Non-wasm targets require `Send + Sync` on the boxed error so the
    // variant can cross thread boundaries; the wasm variant below drops them.
    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
120
/// Errors that can occur while executing a high-level prompt or chat flow
/// (see the [Prompt] and [Chat] traits).
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxTurnError: (reached max turn limit: {max_turns})")]
    MaxTurnsError {
        /// The configured turn limit that was reached.
        max_turns: usize,
        /// The conversation accumulated up to the point the limit was hit.
        chat_history: Box<Vec<Message>>,
        /// The prompt associated with the conversation
        /// (presumably the one in flight when the limit was hit — confirm with callers).
        prompt: Box<Message>,
    },

    /// A prompting loop was cancelled.
    #[error("PromptCancelled: {reason}")]
    PromptCancelled {
        /// The conversation accumulated before cancellation.
        chat_history: Box<Vec<Message>>,
        /// Human-readable reason for the cancellation.
        reason: String,
    },
}
153
154impl PromptError {
155    pub(crate) fn prompt_cancelled(chat_history: Vec<Message>, reason: impl Into<String>) -> Self {
156        Self::PromptCancelled {
157            chat_history: Box::new(chat_history),
158            reason: reason.into(),
159        }
160    }
161}
162
/// A text document that can be attached to a completion request
/// (e.g. retrieved context in a RAG pipeline).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier for the document; rendered into the `<file id: …>` tag by `Display`.
    pub id: String,
    /// The document body.
    pub text: String,
    // `flatten` makes every extra key (de)serialize at the top level of the
    // document object rather than nested under an `additional_props` field.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
170
171impl std::fmt::Display for Document {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        write!(
174            f,
175            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
176            self.id,
177            if self.additional_props.is_empty() {
178                self.text.clone()
179            } else {
180                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
181                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
182                let metadata = sorted_props
183                    .iter()
184                    .map(|(k, v)| format!("{k}: {v:?}"))
185                    .collect::<Vec<_>>()
186                    .join(" ");
187                format!("<metadata {} />\n{}", metadata, self.text)
188            }
189        )
190    }
191}
192
/// A tool (function) definition advertised to the completion model provider.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Name the model uses to reference the tool in tool calls.
    pub name: String,
    /// Description of what the tool does, shown to the model.
    pub description: String,
    /// Parameter specification for the tool
    /// (NOTE(review): presumably a JSON Schema object — confirm with providers).
    pub parameters: serde_json::Value,
}
199
// ================================================================
// Implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [std::future::IntoFuture], so it must be
    /// `.await`ed for the request to actually execute.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
218
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [std::future::IntoFuture], so it must be
    /// `.await`ed for the request to actually execute.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
235
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// Returns a [CompletionError] if the request builder could not be generated.
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}
256
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// `T` is the provider-specific raw response type
/// (see [CompletionModel::Response]).
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}
269
/// A trait for grabbing the token usage of a completion response.
///
/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
pub trait GetTokenUsage {
    /// Returns the token usage for this response, or `None` if the
    /// implementor carries no usage information.
    fn token_usage(&self) -> Option<crate::completion::Usage>;
}
276
// A unit response type carries no token usage information.
impl GetTokenUsage for () {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        None
    }
}
282
283impl<T> GetTokenUsage for Option<T>
284where
285    T: GetTokenUsage,
286{
287    fn token_usage(&self) -> Option<crate::completion::Usage> {
288        if let Some(usage) = self {
289            usage.token_usage()
290        } else {
291            None
292        }
293    }
294}
295
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// The total number of tokens used in a given request.
    /// We store this separately as some providers may only report one number,
    /// in which case it may not equal `input_tokens + output_tokens`.
    pub total_tokens: u64,
    /// The number of cached input tokens (from prompt caching). 0 if not reported by provider.
    pub cached_input_tokens: u64,
}
309
310impl Usage {
311    /// Creates a new instance of `Usage`.
312    pub fn new() -> Self {
313        Self {
314            input_tokens: 0,
315            output_tokens: 0,
316            total_tokens: 0,
317            cached_input_tokens: 0,
318        }
319    }
320}
321
impl Default for Usage {
    /// Equivalent to [Usage::new]: all counters zeroed.
    fn default() -> Self {
        Self::new()
    }
}
327
328impl Add for Usage {
329    type Output = Self;
330
331    fn add(self, other: Self) -> Self::Output {
332        Self {
333            input_tokens: self.input_tokens + other.input_tokens,
334            output_tokens: self.output_tokens + other.output_tokens,
335            total_tokens: self.total_tokens + other.total_tokens,
336            cached_input_tokens: self.cached_input_tokens + other.cached_input_tokens,
337        }
338    }
339}
340
341impl AddAssign for Usage {
342    fn add_assign(&mut self, other: Self) {
343        self.input_tokens += other.input_tokens;
344        self.output_tokens += other.output_tokens;
345        self.total_tokens += other.total_tokens;
346        self.cached_input_tokens += other.cached_input_tokens;
347    }
348}
349
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// The provider client type this model is constructed from (see [CompletionModel::make]).
    type Client;

    /// Constructs an instance of this completion model from a provider client
    /// and a model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
389
/// Dyn-compatible (object-safe) counterpart of [CompletionModel], with the
/// provider-specific raw response types erased (`()` for plain completions,
/// [FinalCompletionResponse] for streaming).
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Generates a completion response, discarding the provider-specific raw response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Generates a type-erased streaming completion response.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    /// Generates a completion request builder backed by a type-erased model handle.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
414
// Blanket implementation: every `CompletionModel` is usable through the
// type-erased `CompletionModelDyn` interface.
#[allow(deprecated)]
impl<T, R> CompletionModelDyn for T
where
    T: CompletionModel<StreamingResponse = R>,
    R: Clone + Unpin + GetTokenUsage + 'static,
{
    /// Delegates to [CompletionModel::completion], replacing the raw provider
    /// response with `()` while keeping the choice and usage intact.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
        Box::pin(async move {
            self.completion(request)
                .await
                .map(|resp| CompletionResponse {
                    choice: resp.choice,
                    usage: resp.usage,
                    raw_response: (),
                })
        })
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    > {
        Box::pin(async move {
            let resp = self.stream(request).await?;
            let inner = resp.inner;

            // Box the provider-specific stream behind the dyn wrapper so the
            // concrete `StreamingResponse` type no longer appears in the signature.
            let stream = streaming::StreamingResultDyn {
                inner: Box::pin(inner),
            };

            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
        })
    }

    /// Generates a completion request builder for the given `prompt`.
    // NOTE(review): this clones the model into an `Arc` per call — presumably
    // cheap since `CompletionModel: Clone` is expected to be a thin handle.
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
        CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
    }
}
463
/// Struct representing a general completion request that can be sent to a completion model provider.
///
/// Usually constructed via [CompletionRequestBuilder] rather than directly.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    /// (see [CompletionRequest::normalized_documents] for the message form).
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}
485
486impl CompletionRequest {
487    /// Returns documents normalized into a message (if any).
488    /// Most providers do not accept documents directly as input, so it needs to convert into a
489    ///  `Message` so that it can be incorporated into `chat_history` as a
490    pub fn normalized_documents(&self) -> Option<Message> {
491        if self.documents.is_empty() {
492            return None;
493        }
494
495        // Most providers will convert documents into a text unless it can handle document messages.
496        // We use `UserContent::document` for those who handle it directly!
497        let messages = self
498            .documents
499            .iter()
500            .map(|doc| {
501                UserContent::document(
502                    doc.to_string(),
503                    // In the future, we can customize `Document` to pass these extra types through.
504                    // Most providers ditch these but they might want to use them.
505                    Some(DocumentMediaType::TXT),
506                )
507            })
508            .collect::<Vec<_>>();
509
510        Some(Message::User {
511            content: OneOrMany::many(messages).expect("There will be atleast one document"),
512        })
513    }
514}
515
516/// Builder struct for constructing a completion request.
517///
518/// Example usage:
519/// ```rust
520/// use rig::{
521///     providers::openai::{Client, self},
522///     completion::CompletionRequestBuilder,
523/// };
524///
525/// let openai = Client::new("your-openai-api-key");
526/// let model = openai.completion_model(openai::GPT_4O).build();
527///
528/// // Create the completion request and execute it separately
529/// let request = CompletionRequestBuilder::new(model, "Who are you?".to_string())
530///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
531///     .temperature(0.5)
532///     .build();
533///
534/// let response = model.completion(request)
535///     .await
536///     .expect("Failed to get completion response");
537/// ```
538///
539/// Alternatively, you can execute the completion request directly from the builder:
540/// ```rust
541/// use rig::{
542///     providers::openai::{Client, self},
543///     completion::CompletionRequestBuilder,
544/// };
545///
546/// let openai = Client::new("your-openai-api-key");
547/// let model = openai.completion_model(openai::GPT_4O).build();
548///
549/// // Create the completion request and execute it directly
550/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
551///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
552///     .temperature(0.5)
553///     .send()
554///     .await
555///     .expect("Failed to get completion response");
556/// ```
557///
558/// Note: It is usually unnecessary to create a completion request builder directly.
559/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    // The model the built request will be sent to by `send`/`stream`.
    model: M,
    // The prompt; appended as the final chat-history message by `build`.
    prompt: Message,
    // Optional fields, all `None`/empty until set via the builder methods.
    preamble: Option<String>,
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
}
572
impl<M: CompletionModel> CompletionRequestBuilder<M> {
    /// Creates a new builder for `model` with the given `prompt` and all
    /// optional fields unset.
    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
        Self {
            model,
            prompt: prompt.into(),
            preamble: None,
            chat_history: Vec::new(),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    /// Sets the preamble for the completion request.
    pub fn preamble(mut self, preamble: String) -> Self {
        self.preamble = Some(preamble);
        self
    }

    /// Clears any previously set preamble from the completion request.
    pub fn without_preamble(mut self) -> Self {
        self.preamble = None;
        self
    }

    /// Adds a message to the chat history for the completion request.
    pub fn message(mut self, message: Message) -> Self {
        self.chat_history.push(message);
        self
    }

    /// Adds a list of messages to the chat history for the completion request.
    pub fn messages(self, messages: Vec<Message>) -> Self {
        messages
            .into_iter()
            .fold(self, |builder, msg| builder.message(msg))
    }

    /// Adds a document to the completion request.
    pub fn document(mut self, document: Document) -> Self {
        self.documents.push(document);
        self
    }

    /// Adds a list of documents to the completion request.
    pub fn documents(self, documents: Vec<Document>) -> Self {
        documents
            .into_iter()
            .fold(self, |builder, doc| builder.document(doc))
    }

    /// Adds a tool to the completion request.
    pub fn tool(mut self, tool: ToolDefinition) -> Self {
        self.tools.push(tool);
        self
    }

    /// Adds a list of tools to the completion request.
    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
        tools
            .into_iter()
            .fold(self, |builder, tool| builder.tool(tool))
    }

    /// Adds additional parameters to the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    ///
    /// Unlike [Self::additional_params_opt], this merges into any parameters
    /// that were already set instead of replacing them.
    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
        match self.additional_params {
            Some(params) => {
                self.additional_params = Some(json_utils::merge(params, additional_params));
            }
            None => {
                self.additional_params = Some(additional_params);
            }
        }
        self
    }

    /// Sets the additional parameters for the completion request.
    /// This can be used to set additional provider-specific parameters. For example,
    /// Cohere's completion models accept a `connectors` parameter that can be used to
    /// specify the data connectors used by Cohere when executing the completion
    /// (see `examples/cohere_connectors.rs`).
    ///
    /// Note: this replaces (rather than merges with) any previously set parameters.
    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
        self.additional_params = additional_params;
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the temperature for the completion request.
    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
        self.temperature = temperature;
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the max tokens for the completion request.
    /// Note: This is required if using Anthropic
    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Sets the tool choice for the completion request (i.e.: whether tools are
    /// required to be used by the model provider before providing a response).
    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
        self.tool_choice = Some(tool_choice);
        self
    }

    /// Builds the completion request.
    pub fn build(self) -> CompletionRequest {
        // The prompt is appended as the final message, so the resulting
        // `OneOrMany` is guaranteed non-empty.
        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
            .expect("There will always be atleast the prompt");

        CompletionRequest {
            preamble: self.preamble,
            chat_history,
            documents: self.documents,
            tools: self.tools,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            tool_choice: self.tool_choice,
            additional_params: self.additional_params,
        }
    }

    /// Sends the completion request to the completion model provider and returns the completion response.
    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
        // Clone the model first: `build` consumes `self`, so the model cannot
        // remain borrowed from `self` across the call.
        let model = self.model.clone();
        model.completion(self.build()).await
    }

    /// Stream the completion request
    pub async fn stream<'a>(
        self,
    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
    where
        <M as CompletionModel>::StreamingResponse: 'a,
        Self: 'a,
    {
        let model = self.model.clone();
        model.stream(self.build()).await
    }
}
733
#[cfg(test)]
mod tests {
    // Unit tests for `Document` rendering and `CompletionRequest` document
    // normalization.

    use super::*;

    // With no metadata, the document renders as a bare `<file>` block.
    #[test]
    fn test_document_display_without_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::new(),
        };

        let expected = "<file id: 123>\nThis is a test document.\n</file>\n";
        assert_eq!(format!("{doc}"), expected);
    }

    // With metadata, a single `<metadata … />` line (keys sorted) precedes the text.
    #[test]
    fn test_document_display_with_metadata() {
        let mut additional_props = HashMap::new();
        additional_props.insert("author".to_string(), "John Doe".to_string());
        additional_props.insert("length".to_string(), "42".to_string());

        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props,
        };

        let expected = concat!(
            "<file id: 123>\n",
            "<metadata author: \"John Doe\" length: \"42\" />\n",
            "This is a test document.\n",
            "</file>\n"
        );
        assert_eq!(format!("{doc}"), expected);
    }

    // Documents should fold into one user message with one content item per document.
    #[test]
    fn test_normalize_documents_with_documents() {
        let doc1 = Document {
            id: "doc1".to_string(),
            text: "Document 1 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let doc2 = Document {
            id: "doc2".to_string(),
            text: "Document 2 text.".to_string(),
            additional_props: HashMap::new(),
        };

        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: vec![doc1, doc2],
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        };

        let expected = Message::User {
            content: OneOrMany::many(vec![
                UserContent::document(
                    "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
                UserContent::document(
                    "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                    Some(DocumentMediaType::TXT),
                ),
            ])
            .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    // No documents => no normalized message.
    #[test]
    fn test_normalize_documents_without_documents() {
        let request = CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        };

        assert_eq!(request.normalized_documents(), None);
    }
}
829}