rig/completion/
request.rs

1//! This module provides functionality for working with completion models.
2//! It provides traits, structs, and enums for generating completion requests,
3//! handling completion responses, and defining completion models.
4//!
5//! The main traits defined in this module are:
6//! - [Prompt]: Defines a high-level LLM one-shot prompt interface.
7//! - [Chat]: Defines a high-level LLM chat interface with chat history.
8//! - [Completion]: Defines a low-level LLM completion interface for generating completion requests.
9//! - [CompletionModel]: Defines a completion model that can be used to generate completion
10//!   responses from requests.
11//!
12//! The [Prompt] and [Chat] traits are high level traits that users are expected to use
13//! to interact with LLM models. Moreover, it is good practice to implement one of these
14//! traits for composite agents that use multiple LLM models to generate responses.
15//!
//! The [Completion] trait defines a lower-level interface that is useful when the user wants
//! to further customize the request before sending it to the completion model provider.
18//!
//! The [CompletionModel] trait is meant to act as the interface between providers and
//! the library. It defines the methods that need to be implemented by the user to define
//! a custom base completion model (e.g. a private or third-party LLM provider).
22//!
23//! The module also provides various structs and enums for representing generic completion requests,
24//! responses, and errors.
25//!
26//! Example Usage:
27//! ```rust
28//! use rig::providers::openai::{Client, self};
29//! use rig::completion::*;
30//!
31//! // Initialize the OpenAI client and a completion model
32//! let openai = Client::new("your-openai-api-key");
33//!
34//! let gpt_4 = openai.completion_model(openai::GPT_4);
35//!
36//! // Create the completion request
//! let request = gpt_4.completion_request("Who are you?")
//!     .preamble("\
//!         You are Marvin, an extremely smart but depressed robot who is \
//!         nonetheless helpful towards humanity.\
//!     ".to_string())
//!     .temperature(0.5)
//!     .build();
44//!
45//! // Send the completion request and get the completion response
46//! let response = gpt_4.completion(request)
47//!     .await
48//!     .expect("Failed to get completion response");
49//!
//! // Handle the completion response: `choice` holds one or more pieces of assistant content
//! for content in response.choice.iter() {
//!     match content {
//!         AssistantContent::Text(text) => {
//!             // Handle the completion response as a message
//!             println!("Received message: {}", text.text);
//!         }
//!         AssistantContent::ToolCall(tool_call) => {
//!             // Handle the completion response as a tool call
//!             println!(
//!                 "Received tool call: {} {:?}",
//!                 tool_call.function.name, tool_call.function.arguments
//!             );
//!         }
//!         _ => {}
//!     }
//! }
61//! ```
62//!
63//! For more information on how to use the completion functionality, refer to the documentation of
64//! the individual traits, structs, and enums defined in this module.
65
66use super::message::{AssistantContent, DocumentMediaType};
67use crate::client::FinalCompletionResponse;
68#[allow(deprecated)]
69use crate::client::completion::CompletionModelHandle;
70use crate::message::ToolChoice;
71use crate::streaming::StreamingCompletionResponse;
72use crate::tool::server::ToolServerError;
73use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
74use crate::{OneOrMany, http_client, streaming};
75use crate::{
76    json_utils,
77    message::{Message, UserContent},
78    tool::ToolSetError,
79};
80use serde::de::DeserializeOwned;
81use serde::{Deserialize, Serialize};
82use std::collections::HashMap;
83use std::ops::{Add, AddAssign};
84use std::sync::Arc;
85use thiserror::Error;
86
// Errors

/// Errors that can occur while building a completion request, sending it to a
/// completion model provider, or parsing the provider's response.
#[derive(Debug, Error)]
pub enum CompletionError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Url error (e.g.: invalid URL)
    #[error("UrlError: {0}")]
    UrlError(#[from] url::ParseError),

    #[cfg(not(target_family = "wasm"))]
    /// Error building the completion request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    #[cfg(target_family = "wasm")]
    /// Error building the completion request
    // On wasm targets the boxed error is not required to be `Send + Sync`.
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + 'static>),

    /// Error parsing the completion response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the completion model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
120
/// Prompt errors
///
/// Error type returned by the high-level [Prompt] and [Chat] interfaces.
#[derive(Debug, Error)]
pub enum PromptError {
    /// Something went wrong with the completion
    #[error("CompletionError: {0}")]
    CompletionError(#[from] CompletionError),

    /// There was an error while using a tool
    #[error("ToolCallError: {0}")]
    ToolError(#[from] ToolSetError),

    /// There was an issue while executing a tool on a tool server
    #[error("ToolServerError: {0}")]
    ToolServerError(#[from] ToolServerError),

    /// The LLM tried to call too many tools during a multi-turn conversation.
    /// To fix this, you may either need to lower the amount of tools your model has access to (and then create other agents to share the tool load)
    /// or increase the amount of turns given in `.multi_turn()`.
    #[error("MaxDepthError: (reached limit: {max_depth})")]
    MaxDepthError {
        /// The depth limit that was reached (reported in the error message).
        max_depth: usize,
        /// Chat history carried with the error.
        // NOTE(review): presumably boxed to keep `PromptError` small; confirm.
        chat_history: Box<Vec<Message>>,
        /// Prompt message carried with the error.
        prompt: Message,
    },

    /// A prompting loop was cancelled.
    ///
    /// Constructed internally via `PromptError::prompt_cancelled`.
    #[error("PromptCancelled")]
    PromptCancelled { chat_history: Box<Vec<Message>> },
}
150
151impl PromptError {
152    pub(crate) fn prompt_cancelled(chat_history: Vec<Message>) -> Self {
153        Self::PromptCancelled {
154            chat_history: Box::new(chat_history),
155        }
156    }
157}
158
/// A document that can be attached to a [`CompletionRequest`].
///
/// `additional_props` is flattened into the surrounding object when (de)serialized with
/// serde, and its entries are rendered as metadata by this type's `Display` implementation.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Document {
    /// Identifier rendered in the `<file id: ...>` header.
    pub id: String,
    /// The document body.
    pub text: String,
    /// Extra key/value metadata, flattened by serde.
    #[serde(flatten)]
    pub additional_props: HashMap<String, String>,
}
166
167impl std::fmt::Display for Document {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        write!(
170            f,
171            concat!("<file id: {}>\n", "{}\n", "</file>\n"),
172            self.id,
173            if self.additional_props.is_empty() {
174                self.text.clone()
175            } else {
176                let mut sorted_props = self.additional_props.iter().collect::<Vec<_>>();
177                sorted_props.sort_by(|a, b| a.0.cmp(b.0));
178                let metadata = sorted_props
179                    .iter()
180                    .map(|(k, v)| format!("{k}: {v:?}"))
181                    .collect::<Vec<_>>()
182                    .join(" ");
183                format!("<metadata {} />\n{}", metadata, self.text)
184            }
185        )
186    }
187}
188
/// Definition of a tool that can be sent to a completion model provider
/// (see [`CompletionRequest::tools`]).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// The tool's parameters as a JSON value.
    // NOTE(review): presumably a JSON Schema object describing the arguments; confirm per provider.
    pub parameters: serde_json::Value,
}
195
// ================================================================
// Core traits, types, and implementations
// ================================================================
/// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out).
pub trait Prompt: WasmCompatSend + WasmCompatSync {
    /// Send a simple prompt to the underlying completion model.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and
    /// the result is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [`IntoFuture`](std::future::IntoFuture), so it can be
    /// `.await`ed directly.
    ///
    /// # Errors
    /// Returns a [PromptError] when the completion or a tool invocation fails.
    fn prompt(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
214
/// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out).
pub trait Chat: WasmCompatSend + WasmCompatSync {
    /// Send a prompt with optional chat history to the underlying completion model.
    ///
    /// `chat_history` may be empty, in which case this behaves like a one-shot prompt.
    ///
    /// If the completion model's response is a message, then it is returned as a string.
    ///
    /// If the completion model's response is a tool call, then the tool is called and the result
    /// is returned as a string.
    ///
    /// If the tool does not exist, or the tool call fails, then an error is returned.
    ///
    /// The returned value implements [`IntoFuture`](std::future::IntoFuture), so it can be
    /// `.await`ed directly.
    ///
    /// # Errors
    /// Returns a [PromptError] when the completion or a tool invocation fails.
    fn chat(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::IntoFuture<Output = Result<String, PromptError>, IntoFuture: WasmCompatSend>;
}
231
/// Trait defining a low-level LLM completion interface
pub trait Completion<M: CompletionModel> {
    /// Generates a completion request builder for the given `prompt` and `chat_history`.
    /// This function is meant to be called by the user to further customize the
    /// request at prompt time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// For example, the request builder returned by [`Agent::completion`](crate::agent::Agent::completion) will already
    /// contain the `preamble` provided when creating the agent.
    ///
    /// # Errors
    /// Returns a [CompletionError] if the request builder cannot be produced.
    fn completion(
        &self,
        prompt: impl Into<Message> + WasmCompatSend,
        chat_history: Vec<Message>,
    ) -> impl std::future::Future<Output = Result<CompletionRequestBuilder<M>, CompletionError>>
    + WasmCompatSend;
}
252
/// General completion response struct that contains the high-level completion choice
/// and the raw response. The completion choice contains one or more assistant content.
///
/// `T` is the provider-specific raw response type: [`CompletionModel::completion`] returns
/// `CompletionResponse<Self::Response>`, while the type-erased `CompletionModelDyn` uses `()`.
#[derive(Debug)]
pub struct CompletionResponse<T> {
    /// The completion choice (represented by one or more assistant message content)
    /// returned by the completion model provider
    pub choice: OneOrMany<AssistantContent>,
    /// Tokens used during prompting and responding
    pub usage: Usage,
    /// The raw response returned by the completion model provider
    pub raw_response: T,
}
265
266/// A trait for grabbing the token usage of a completion response.
267///
268/// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do.
269pub trait GetTokenUsage {
270    fn token_usage(&self) -> Option<crate::completion::Usage>;
271}
272
273impl GetTokenUsage for () {
274    fn token_usage(&self) -> Option<crate::completion::Usage> {
275        None
276    }
277}
278
279impl<T> GetTokenUsage for Option<T>
280where
281    T: GetTokenUsage,
282{
283    fn token_usage(&self) -> Option<crate::completion::Usage> {
284        if let Some(usage) = self {
285            usage.token_usage()
286        } else {
287            None
288        }
289    }
290}
291
/// Struct representing the token usage for a completion request.
/// If tokens used are `0`, then the provider failed to supply token usage metrics.
///
/// Usages can be combined with `+` and `+=`, which add each counter field-wise.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct Usage {
    /// The number of input ("prompt") tokens used in a given request.
    pub input_tokens: u64,
    /// The number of output ("completion") tokens used in a given request.
    pub output_tokens: u64,
    /// We store this separately as some providers may only report one number
    pub total_tokens: u64,
}
303
304impl Usage {
305    /// Creates a new instance of `Usage`.
306    pub fn new() -> Self {
307        Self {
308            input_tokens: 0,
309            output_tokens: 0,
310            total_tokens: 0,
311        }
312    }
313}
314
315impl Default for Usage {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321impl Add for Usage {
322    type Output = Self;
323
324    fn add(self, other: Self) -> Self::Output {
325        Self {
326            input_tokens: self.input_tokens + other.input_tokens,
327            output_tokens: self.output_tokens + other.output_tokens,
328            total_tokens: self.total_tokens + other.total_tokens,
329        }
330    }
331}
332
333impl AddAssign for Usage {
334    fn add_assign(&mut self, other: Self) {
335        self.input_tokens += other.input_tokens;
336        self.output_tokens += other.output_tokens;
337        self.total_tokens += other.total_tokens;
338    }
339}
340
/// Trait defining a completion model that can be used to generate completion responses.
/// This trait is meant to be implemented by the user to define a custom completion model,
/// either from a third party provider (e.g.: OpenAI) or a local model.
pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync {
    /// The raw response type returned by the underlying completion model.
    type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned;
    /// The raw response type returned by the underlying completion model when streaming.
    ///
    /// Must implement [GetTokenUsage] so token usage can be read from streamed responses.
    type StreamingResponse: Clone
        + Unpin
        + WasmCompatSend
        + WasmCompatSync
        + Serialize
        + DeserializeOwned
        + GetTokenUsage;

    /// The client type this model is constructed from (see [`CompletionModel::make`]).
    type Client;

    /// Constructs a model instance from `client` for the model identified by `model`.
    // NOTE(review): `model` is presumably the provider's model name (e.g. "gpt-4o"); confirm.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Generates a completion response for the given completion request.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<CompletionResponse<Self::Response>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a streaming completion response for the given completion request.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl std::future::Future<
        Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>,
    > + WasmCompatSend;

    /// Generates a completion request builder for the given `prompt`.
    ///
    /// The builder holds a clone of this model, so the request can be sent directly with
    /// [`CompletionRequestBuilder::send`] or [`CompletionRequestBuilder::stream`].
    fn completion_request(&self, prompt: impl Into<Message>) -> CompletionRequestBuilder<Self> {
        CompletionRequestBuilder::new(self.clone(), prompt)
    }
}
380
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `CompletionModel` instead."
)]
/// Type-erased counterpart of [CompletionModel].
///
/// Methods return boxed futures, the raw completion response is replaced by `()`, and
/// streaming responses are exposed as `FinalCompletionResponse`. Every [CompletionModel]
/// gets this trait through a blanket implementation.
// NOTE(review): presumably intended for use behind `CompletionModelHandle` / trait objects; confirm.
pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync {
    /// Generates a completion response, discarding the provider-specific raw response.
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>>;

    /// Generates a streaming completion response with a type-erased final response.
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> WasmBoxedFuture<
        '_,
        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
    >;

    /// Generates a completion request builder backed by a [`CompletionModelHandle`].
    fn completion_request(
        &self,
        prompt: Message,
    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>>;
}
405
406#[allow(deprecated)]
407impl<T, R> CompletionModelDyn for T
408where
409    T: CompletionModel<StreamingResponse = R>,
410    R: Clone + Unpin + GetTokenUsage + 'static,
411{
412    fn completion(
413        &self,
414        request: CompletionRequest,
415    ) -> WasmBoxedFuture<'_, Result<CompletionResponse<()>, CompletionError>> {
416        Box::pin(async move {
417            self.completion(request)
418                .await
419                .map(|resp| CompletionResponse {
420                    choice: resp.choice,
421                    usage: resp.usage,
422                    raw_response: (),
423                })
424        })
425    }
426
427    fn stream(
428        &self,
429        request: CompletionRequest,
430    ) -> WasmBoxedFuture<
431        '_,
432        Result<StreamingCompletionResponse<FinalCompletionResponse>, CompletionError>,
433    > {
434        Box::pin(async move {
435            let resp = self.stream(request).await?;
436            let inner = resp.inner;
437
438            let stream = streaming::StreamingResultDyn {
439                inner: Box::pin(inner),
440            };
441
442            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
443        })
444    }
445
446    /// Generates a completion request builder for the given `prompt`.
447    fn completion_request(
448        &self,
449        prompt: Message,
450    ) -> CompletionRequestBuilder<CompletionModelHandle<'_>> {
451        CompletionRequestBuilder::new(CompletionModelHandle::new(Arc::new(self.clone())), prompt)
452    }
453}
454
/// Struct representing a general completion request that can be sent to a completion model provider.
///
/// Usually produced by [`CompletionRequestBuilder::build`]. Providers that cannot accept
/// documents directly can fold them into the conversation with
/// [`CompletionRequest::normalized_documents`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// The preamble to be sent to the completion model provider
    pub preamble: Option<String>,
    /// The chat history to be sent to the completion model provider.
    /// The very last message will always be the prompt (hence why there is *always* one)
    pub chat_history: OneOrMany<Message>,
    /// The documents to be sent to the completion model provider
    pub documents: Vec<Document>,
    /// The tools to be sent to the completion model provider
    pub tools: Vec<ToolDefinition>,
    /// The temperature to be sent to the completion model provider
    pub temperature: Option<f64>,
    /// The max tokens to be sent to the completion model provider
    pub max_tokens: Option<u64>,
    /// Whether tools are required to be used by the model provider or not before providing a response.
    pub tool_choice: Option<ToolChoice>,
    /// Additional provider-specific parameters to be sent to the completion model provider
    pub additional_params: Option<serde_json::Value>,
}
476
477impl CompletionRequest {
478    /// Returns documents normalized into a message (if any).
479    /// Most providers do not accept documents directly as input, so it needs to convert into a
480    ///  `Message` so that it can be incorporated into `chat_history` as a
481    pub fn normalized_documents(&self) -> Option<Message> {
482        if self.documents.is_empty() {
483            return None;
484        }
485
486        // Most providers will convert documents into a text unless it can handle document messages.
487        // We use `UserContent::document` for those who handle it directly!
488        let messages = self
489            .documents
490            .iter()
491            .map(|doc| {
492                UserContent::document(
493                    doc.to_string(),
494                    // In the future, we can customize `Document` to pass these extra types through.
495                    // Most providers ditch these but they might want to use them.
496                    Some(DocumentMediaType::TXT),
497                )
498            })
499            .collect::<Vec<_>>();
500
501        Some(Message::User {
502            content: OneOrMany::many(messages).expect("There will be atleast one document"),
503        })
504    }
505}
506
/// Builder struct for constructing a completion request.
///
/// Example usage:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::{CompletionModel, CompletionRequestBuilder},
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it separately.
/// // The builder takes ownership of the model, so pass a clone to keep using `model`.
/// let request = CompletionRequestBuilder::new(model.clone(), "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .build();
///
/// let response = model.completion(request)
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Alternatively, you can execute the completion request directly from the builder:
/// ```rust
/// use rig::{
///     providers::openai::{Client, self},
///     completion::CompletionRequestBuilder,
/// };
///
/// let openai = Client::new("your-openai-api-key");
/// let model = openai.completion_model(openai::GPT_4O);
///
/// // Create the completion request and execute it directly
/// let response = CompletionRequestBuilder::new(model, "Who are you?".to_string())
///     .preamble("You are Marvin from the Hitchhiker's Guide to the Galaxy.".to_string())
///     .temperature(0.5)
///     .send()
///     .await
///     .expect("Failed to get completion response");
/// ```
///
/// Note: It is usually unnecessary to create a completion request builder directly.
/// Instead, use the [CompletionModel::completion_request] method.
pub struct CompletionRequestBuilder<M: CompletionModel> {
    /// Model the built request is sent to by `send` / `stream`.
    model: M,
    /// The prompt; appended after `chat_history` when the request is built.
    prompt: Message,
    preamble: Option<String>,
    /// Prior messages, in insertion order.
    chat_history: Vec<Message>,
    documents: Vec<Document>,
    tools: Vec<ToolDefinition>,
    temperature: Option<f64>,
    max_tokens: Option<u64>,
    tool_choice: Option<ToolChoice>,
    additional_params: Option<serde_json::Value>,
}
563
564impl<M: CompletionModel> CompletionRequestBuilder<M> {
565    pub fn new(model: M, prompt: impl Into<Message>) -> Self {
566        Self {
567            model,
568            prompt: prompt.into(),
569            preamble: None,
570            chat_history: Vec::new(),
571            documents: Vec::new(),
572            tools: Vec::new(),
573            temperature: None,
574            max_tokens: None,
575            tool_choice: None,
576            additional_params: None,
577        }
578    }
579
580    /// Sets the preamble for the completion request.
581    pub fn preamble(mut self, preamble: String) -> Self {
582        self.preamble = Some(preamble);
583        self
584    }
585
586    pub fn without_preamble(mut self) -> Self {
587        self.preamble = None;
588        self
589    }
590
591    /// Adds a message to the chat history for the completion request.
592    pub fn message(mut self, message: Message) -> Self {
593        self.chat_history.push(message);
594        self
595    }
596
597    /// Adds a list of messages to the chat history for the completion request.
598    pub fn messages(self, messages: Vec<Message>) -> Self {
599        messages
600            .into_iter()
601            .fold(self, |builder, msg| builder.message(msg))
602    }
603
604    /// Adds a document to the completion request.
605    pub fn document(mut self, document: Document) -> Self {
606        self.documents.push(document);
607        self
608    }
609
610    /// Adds a list of documents to the completion request.
611    pub fn documents(self, documents: Vec<Document>) -> Self {
612        documents
613            .into_iter()
614            .fold(self, |builder, doc| builder.document(doc))
615    }
616
617    /// Adds a tool to the completion request.
618    pub fn tool(mut self, tool: ToolDefinition) -> Self {
619        self.tools.push(tool);
620        self
621    }
622
623    /// Adds a list of tools to the completion request.
624    pub fn tools(self, tools: Vec<ToolDefinition>) -> Self {
625        tools
626            .into_iter()
627            .fold(self, |builder, tool| builder.tool(tool))
628    }
629
630    /// Adds additional parameters to the completion request.
631    /// This can be used to set additional provider-specific parameters. For example,
632    /// Cohere's completion models accept a `connectors` parameter that can be used to
633    /// specify the data connectors used by Cohere when executing the completion
634    /// (see `examples/cohere_connectors.rs`).
635    pub fn additional_params(mut self, additional_params: serde_json::Value) -> Self {
636        match self.additional_params {
637            Some(params) => {
638                self.additional_params = Some(json_utils::merge(params, additional_params));
639            }
640            None => {
641                self.additional_params = Some(additional_params);
642            }
643        }
644        self
645    }
646
647    /// Sets the additional parameters for the completion request.
648    /// This can be used to set additional provider-specific parameters. For example,
649    /// Cohere's completion models accept a `connectors` parameter that can be used to
650    /// specify the data connectors used by Cohere when executing the completion
651    /// (see `examples/cohere_connectors.rs`).
652    pub fn additional_params_opt(mut self, additional_params: Option<serde_json::Value>) -> Self {
653        self.additional_params = additional_params;
654        self
655    }
656
657    /// Sets the temperature for the completion request.
658    pub fn temperature(mut self, temperature: f64) -> Self {
659        self.temperature = Some(temperature);
660        self
661    }
662
663    /// Sets the temperature for the completion request.
664    pub fn temperature_opt(mut self, temperature: Option<f64>) -> Self {
665        self.temperature = temperature;
666        self
667    }
668
669    /// Sets the max tokens for the completion request.
670    /// Note: This is required if using Anthropic
671    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
672        self.max_tokens = Some(max_tokens);
673        self
674    }
675
676    /// Sets the max tokens for the completion request.
677    /// Note: This is required if using Anthropic
678    pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
679        self.max_tokens = max_tokens;
680        self
681    }
682
683    /// Sets the thing.
684    pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
685        self.tool_choice = Some(tool_choice);
686        self
687    }
688
689    /// Builds the completion request.
690    pub fn build(self) -> CompletionRequest {
691        let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat())
692            .expect("There will always be atleast the prompt");
693
694        CompletionRequest {
695            preamble: self.preamble,
696            chat_history,
697            documents: self.documents,
698            tools: self.tools,
699            temperature: self.temperature,
700            max_tokens: self.max_tokens,
701            tool_choice: self.tool_choice,
702            additional_params: self.additional_params,
703        }
704    }
705
706    /// Sends the completion request to the completion model provider and returns the completion response.
707    pub async fn send(self) -> Result<CompletionResponse<M::Response>, CompletionError> {
708        let model = self.model.clone();
709        model.completion(self.build()).await
710    }
711
712    /// Stream the completion request
713    pub async fn stream<'a>(
714        self,
715    ) -> Result<StreamingCompletionResponse<M::StreamingResponse>, CompletionError>
716    where
717        <M as CompletionModel>::StreamingResponse: 'a,
718        Self: 'a,
719    {
720        let model = self.model.clone();
721        model.stream(self.build()).await
722    }
723}
724
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document with no additional metadata.
    fn plain_doc(id: &str, text: &str) -> Document {
        Document {
            id: id.to_string(),
            text: text.to_string(),
            additional_props: HashMap::new(),
        }
    }

    /// Builds a minimal request with a fixed question and the given documents attached.
    fn request_with_documents(documents: Vec<Document>) -> CompletionRequest {
        CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one("What is the capital of France?".into()),
            documents,
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
        }
    }

    #[test]
    fn test_document_display_without_metadata() {
        let doc = plain_doc("123", "This is a test document.");

        assert_eq!(
            doc.to_string(),
            "<file id: 123>\nThis is a test document.\n</file>\n"
        );
    }

    #[test]
    fn test_document_display_with_metadata() {
        let doc = Document {
            id: "123".to_string(),
            text: "This is a test document.".to_string(),
            additional_props: HashMap::from([
                ("author".to_string(), "John Doe".to_string()),
                ("length".to_string(), "42".to_string()),
            ]),
        };

        let rendered = doc.to_string();
        assert_eq!(
            rendered,
            concat!(
                "<file id: 123>\n",
                "<metadata author: \"John Doe\" length: \"42\" />\n",
                "This is a test document.\n",
                "</file>\n"
            )
        );
    }

    #[test]
    fn test_normalize_documents_with_documents() {
        let request = request_with_documents(vec![
            plain_doc("doc1", "Document 1 text."),
            plain_doc("doc2", "Document 2 text."),
        ]);

        let expected_contents = vec![
            UserContent::document(
                "<file id: doc1>\nDocument 1 text.\n</file>\n".to_string(),
                Some(DocumentMediaType::TXT),
            ),
            UserContent::document(
                "<file id: doc2>\nDocument 2 text.\n</file>\n".to_string(),
                Some(DocumentMediaType::TXT),
            ),
        ];
        let expected = Message::User {
            content: OneOrMany::many(expected_contents)
                .expect("There will be at least one document"),
        };

        assert_eq!(request.normalized_documents(), Some(expected));
    }

    #[test]
    fn test_normalize_documents_without_documents() {
        let request = request_with_documents(Vec::new());

        assert_eq!(request.normalized_documents(), None);
    }
}